Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding tabular regression pipeline #85

Merged
Merged
Show file tree
Hide file tree
Changes from 40 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
9d50cb6
removed old supported_tasks dictionary from heads, added some docstri…
bastiscode Feb 1, 2021
b7c8773
removed old supported_tasks attribute and updated doc strings in base…
bastiscode Feb 1, 2021
725faf2
removed old supported_tasks attribute from network backbones
bastiscode Feb 1, 2021
740a604
put time series backbones in separate files, add doc strings and refa…
bastiscode Feb 1, 2021
b727016
split image networks into separate files, add doc strings and refacto…
bastiscode Feb 1, 2021
bc77ca3
fix typo
bastiscode Feb 1, 2021
f8de549
add an intial simple backbone test similar to the network head test
bastiscode Feb 1, 2021
480b8ea
fix flake8
bastiscode Feb 1, 2021
f461c7e
fixed imports in backbones and heads
bastiscode Feb 2, 2021
cab8f83
added new network backbone and head tests
bastiscode Feb 2, 2021
ab2f5e9
enabled tests for adding custom backbones and heads, added required p…
bastiscode Feb 2, 2021
41e5974
adding tabular regression pipeline
bastiscode Feb 4, 2021
d7037ac
upstream changes
bastiscode Feb 4, 2021
1b5fc46
fix flake8
bastiscode Feb 1, 2021
7987c86
adding tabular regression pipeline
bastiscode Feb 4, 2021
75f49b1
merged remote
bastiscode Feb 4, 2021
1dbf53f
fix flake8
bastiscode Feb 4, 2021
1726105
fix regression test
bastiscode Feb 4, 2021
eb02feb
fix indentation and comments, undo change in base network
bastiscode Feb 9, 2021
34e6bf4
pipeline fitting tests now check the expected output shape dynamicall…
bastiscode Feb 9, 2021
1f0444f
refactored trainer tests, added trainer test for regression
bastiscode Feb 9, 2021
7efd048
remove regression from mixup unitest
bastiscode Feb 9, 2021
2aebcee
use pandas unique instead of numpy
bastiscode Feb 10, 2021
6668509
[IMPORTANT] added proper target casting based on task type to base tr…
bastiscode Feb 10, 2021
29bbdef
adding tabular regression task to api
bastiscode Feb 10, 2021
fb9e175
adding tabular regression example, some small fixes
bastiscode Feb 10, 2021
04521f8
new/more tests for tabular regression
bastiscode Feb 10, 2021
071f3b8
Merge branch 'refactor_development' into refactor_development
bastiscode Feb 10, 2021
8833bc6
fix mypy and flake8 errors from merge
bastiscode Feb 10, 2021
73ccc7c
fix issues with new weighted loss and regression tasks
bastiscode Feb 10, 2021
760296e
change tabular column transformer to use net fit_dictionary_tabular f…
bastiscode Feb 10, 2021
506e55d
fixing tests, replaced num_classes with output_shape
bastiscode Feb 10, 2021
5c46da7
Merge branch 'refactor_development' of github.com:automl/Auto-PyTorch…
bastiscode Feb 15, 2021
85f9995
fixes after merge
bastiscode Feb 15, 2021
1a507e6
adding voting regressor wrapper
bastiscode Feb 16, 2021
44f1980
fix mypy and flake
bastiscode Feb 16, 2021
5a19140
updated example
bastiscode Feb 16, 2021
7d7da2e
lower r2 target
bastiscode Feb 16, 2021
17c2086
address comments
bastiscode Feb 17, 2021
a29dbee
increasing timeout
bastiscode Feb 17, 2021
927fe87
increase number of labels in test_losses because it occasionally fail…
bastiscode Feb 18, 2021
5d582dc
lower regression lr in score test until seeding properly works
bastiscode Feb 18, 2021
07e75f6
fix randomization in feature validator test
bastiscode Feb 18, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Run tests
run: |
if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
python -m pytest --durations=20 --timeout=300 --timeout-method=thread -v $codecov test
python -m pytest --durations=20 --timeout=500 --timeout-method=thread -v $codecov test
- name: Check for files left behind by test
if: ${{ always() }}
run: |
Expand Down
20 changes: 5 additions & 15 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def send_warnings_to_log(
with warnings.catch_warnings():
warnings.showwarning = send_warnings_to_log
if task in REGRESSION_TASKS:
prediction = pipeline.predict(X_, batch_size=batch_size)
# Voting regressor does not support batch size
prediction = pipeline.predict(X_)
else:
# Voting classifier predict proba does not support batch size
prediction = pipeline.predict_proba(X_)
Expand Down Expand Up @@ -161,7 +162,7 @@ def __init__(
delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
delete_output_folder_after_terminate=delete_output_folder_after_terminate,
)
self.task_type = task_type
self.task_type = task_type or ""
self._stopwatch = StopWatch()

self.pipeline_options = replace_string_bool_to_bool(json.load(open(
Expand Down Expand Up @@ -789,7 +790,7 @@ def _search(
max_models_on_disc=self.max_models_on_disc,
seed=self.seed,
max_iterations=None,
read_at_most=np.inf,
read_at_most=sys.maxsize,
ensemble_memory_limit=self._memory_limit,
random_state=self.seed,
precision=precision,
Expand Down Expand Up @@ -1050,7 +1051,7 @@ def predict(

all_predictions = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(_pipeline_predict)(
models[identifier], X_test, batch_size, self._logger, self.task_type
models[identifier], X_test, batch_size, self._logger, STRING_TO_TASK_TYPES[self.task_type]
)
for identifier in self.ensemble_.get_selected_model_identifiers()
)
Expand All @@ -1064,17 +1065,6 @@ def predict(

predictions = self.ensemble_.predict(all_predictions)

if self.task_type in REGRESSION_TASKS:
# Make sure prediction probabilities
# are within a valid range
# Individual models are checked in _pipeline_predict
if (
(predictions >= 0).all() and (predictions <= 1).all()
):
raise ValueError("For ensemble {}, prediction probability not within [0, 1]!".format(
self.ensemble_)
)

self._clean_logger()

return predictions
Expand Down
249 changes: 249 additions & 0 deletions autoPyTorch/api/tabular_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
import os
import uuid
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np

import pandas as pd

from autoPyTorch.api.base_task import BaseTask
from autoPyTorch.constants import (
TABULAR_REGRESSION,
TASK_TYPES_TO_STRING
)
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline
from autoPyTorch.utils.backend import Backend
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates


class TabularRegressionTask(BaseTask):
"""
Tabular Regression API to the pipelines.
Args:
seed (int): seed to be used for reproducibility.
n_jobs (int), (default=1): number of consecutive processes to spawn.
logging_config (Optional[Dict]): specifies configuration
for logging, if None, it is loaded from the logging.yaml
ensemble_size (int), (default=50): Number of models added to the ensemble built by
Ensemble selection from libraries of models.
Models are drawn with replacement.
ensemble_nbest (int), (default=50): only consider the ensemble_nbest
models to build the ensemble
max_models_on_disc (int), (default=50): maximum number of models saved to disc.
Also, controls the size of the ensemble as any additional models will be deleted.
Must be greater than or equal to 1.
temporary_directory (str): folder to store configuration output and log file
output_directory (str): folder to store predictions for optional test set
delete_tmp_folder_after_terminate (bool): determines whether to delete the temporary directory,
when finished
include_components (Optional[Dict]): If None, all possible components are used.
Otherwise specifies set of components to use.
exclude_components (Optional[Dict]): If None, all possible components are used.
Otherwise specifies set of components not to use. Incompatible with include
components
"""

def __init__(
self,
seed: int = 1,
n_jobs: int = 1,
logging_config: Optional[Dict] = None,
ensemble_size: int = 50,
ensemble_nbest: int = 50,
max_models_on_disc: int = 50,
temporary_directory: Optional[str] = None,
output_directory: Optional[str] = None,
delete_tmp_folder_after_terminate: bool = True,
delete_output_folder_after_terminate: bool = True,
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
backend: Optional[Backend] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
):
super().__init__(
seed=seed,
n_jobs=n_jobs,
logging_config=logging_config,
ensemble_size=ensemble_size,
ensemble_nbest=ensemble_nbest,
max_models_on_disc=max_models_on_disc,
temporary_directory=temporary_directory,
output_directory=output_directory,
delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
delete_output_folder_after_terminate=delete_output_folder_after_terminate,
include_components=include_components,
exclude_components=exclude_components,
backend=backend,
resampling_strategy=resampling_strategy,
resampling_strategy_args=resampling_strategy_args,
search_space_updates=search_space_updates,
task_type=TASK_TYPES_TO_STRING[TABULAR_REGRESSION],
)

def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
if not isinstance(dataset, TabularDataset):
raise ValueError("Dataset is incompatible for the given task,: {}".format(
type(dataset)
))
return {'task_type': dataset.task_type,
'output_type': dataset.output_type,
'issparse': dataset.issparse,
'numerical_columns': dataset.numerical_columns,
'categorical_columns': dataset.categorical_columns}

def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
return TabularRegressionPipeline(dataset_properties=dataset_properties)
bastiscode marked this conversation as resolved.
Show resolved Hide resolved

def search(self,
optimize_metric: str,
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
dataset_name: Optional[str] = None,
budget_type: Optional[str] = None,
budget: Optional[float] = None,
total_walltime_limit: int = 100,
func_eval_time_limit: int = 60,
traditional_per_total_budget: float = 0.1,
memory_limit: Optional[int] = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
all_supported_metrics: bool = True,
precision: int = 32,
disable_file_output: List = [],
load_models: bool = True,
) -> 'BaseTask':
"""
Search for the best pipeline configuration for the given dataset.

Fit both optimizes the machine learning models and builds an ensemble out of them.
To disable ensembling, set ensemble_size==0.
using the optimizer.
Args:
X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
A pair of features (X_train) and targets (y_train) used to fit a
pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
be provided to track the generalization performance of each stage.
optimize_metric (str): name of the metric that is used to
evaluate a pipeline.
budget_type (Optional[str]):
Type of budget to be used when fitting the pipeline.
Either 'epochs' or 'runtime'. If not provided, uses
the default in the pipeline config ('epochs')
budget (Optional[float]):
Budget to fit a single run of the pipeline. If not
provided, uses the default in the pipeline config
total_walltime_limit (int), (default=100): Time limit
in seconds for the search of appropriate models.
By increasing this value, autopytorch has a higher
chance of finding better models.
func_eval_time_limit (int), (default=60): Time limit
for a single call to the machine learning model.
Model fitting will be terminated if the machine
learning algorithm runs over the time limit. Set
this value high enough so that typical machine
learning algorithms can be fit on the training
data.
traditional_per_total_budget (float), (default=0.1):
Percent of total walltime to be allocated for
running traditional classifiers.
memory_limit (Optional[int]), (default=4096): Memory
limit in MB for the machine learning algorithm. autopytorch
will stop fitting the machine learning algorithm if it tries
to allocate more than memory_limit MB. If None is provided,
no memory limit is set. In case of multi-processing, memory_limit
will be per job. This memory limit also applies to the ensemble
creation process.
smac_scenario_args (Optional[Dict]): Additional arguments inserted
into the scenario of SMAC. See the
[SMAC documentation] (https://automl.github.io/SMAC3/master/options.html?highlight=scenario#scenario)
get_smac_object_callback (Optional[Callable]): Callback function
to create an object of class
[smac.optimizer.smbo.SMBO](https://automl.github.io/SMAC3/master/apidoc/smac.optimizer.smbo.html).
The function must accept the arguments scenario_dict,
instances, num_params, runhistory, seed and ta. This is
an advanced feature. Use only if you are familiar with
[SMAC](https://automl.github.io/SMAC3/master/index.html).
all_supported_metrics (bool), (default=True): if True, all
metrics supporting current task will be calculated
for each pipeline and results will be available via cv_results
precision (int), (default=32): Numeric precision used when loading
ensemble data. Can be either '16', '32' or '64'.
disable_file_output (Union[bool, List]):
load_models (bool), (default=True): Whether to load the
models after fitting AutoPyTorch.

Returns:
self

"""
if dataset_name is None:
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

# we have to create a logger for at this point for the validator
self._logger = self._get_logger(dataset_name)

# Create a validator object to make sure that the data provided by
# the user matches the autopytorch requirements
self.InputValidator = TabularInputValidator(
is_classification=False,
logger_port=self._logger_port,
)

# Fit a input validator to check the provided data
# Also, an encoder is fit to both train and test data,
# to prevent unseen categories during inference
self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

self.dataset = TabularDataset(
X=X_train, Y=y_train,
X_test=X_test, Y_test=y_test,
validator=self.InputValidator,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
)

return self._search(
dataset=self.dataset,
optimize_metric=optimize_metric,
budget_type=budget_type,
budget=budget,
total_walltime_limit=total_walltime_limit,
func_eval_time_limit=func_eval_time_limit,
traditional_per_total_budget=traditional_per_total_budget,
memory_limit=memory_limit,
smac_scenario_args=smac_scenario_args,
get_smac_object_callback=get_smac_object_callback,
all_supported_metrics=all_supported_metrics,
precision=precision,
disable_file_output=disable_file_output,
load_models=load_models,
)

def predict(
self,
X_test: np.ndarray,
batch_size: Optional[int] = None,
n_jobs: int = 1
) -> np.ndarray:
if self.InputValidator is None or not self.InputValidator._is_fitted:
raise ValueError("predict() is only supported after calling search. Kindly call first "
"the estimator fit() method.")

X_test = self.InputValidator.feature_validator.transform(X_test)
predicted_values = super().predict(X_test, batch_size=batch_size,
n_jobs=n_jobs)

# Allow to predict in the original domain -- that is, the user is not interested
# in our encoded values
return self.InputValidator.target_validator.inverse_transform(predicted_values)
16 changes: 10 additions & 6 deletions autoPyTorch/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torchvision

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
from autoPyTorch.datasets.resampling_strategy import (
CROSS_VAL_FN,
CrossValTypes,
Expand Down Expand Up @@ -113,11 +114,15 @@ def __init__(
self.resampling_strategy_args = resampling_strategy_args
self.task_type: Optional[str] = None
self.issparse: bool = issparse(self.train_tensors[0])
self.input_shape: Tuple[int] = train_tensors[0].shape[1:]
self.num_classes: Optional[int] = None
if len(train_tensors) == 2 and train_tensors[1] is not None:
self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]

if len(self.train_tensors) == 2 and self.train_tensors[1] is not None:
self.output_type: str = type_of_target(self.train_tensors[1])
self.output_shape: int = train_tensors[1].shape[1] if train_tensors[1].shape == 2 else 1

if STRING_TO_OUTPUT_TYPES[self.output_type] in CLASSIFICATION_OUTPUTS:
self.output_shape = len(np.unique(self.train_tensors[1]))
else:
self.output_shape = self.train_tensors[1].shape[-1] if self.train_tensors[1].ndim > 1 else 1

# TODO: Look for a criteria to define small enough to preprocess
self.is_small_preprocess = True
Expand Down Expand Up @@ -368,8 +373,7 @@ def get_dataset_properties(self, dataset_requirements: List[FitRequirement]) ->
'output_type': self.output_type,
'issparse': self.issparse,
'input_shape': self.input_shape,
'output_shape': self.output_shape,
'num_classes': self.num_classes,
'output_shape': self.output_shape
})
return dataset_properties

Expand Down
5 changes: 3 additions & 2 deletions autoPyTorch/evaluation/abstract_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sklearn.base import BaseEstimator
from sklearn.dummy import DummyClassifier, DummyRegressor
from sklearn.ensemble import VotingClassifier, VotingRegressor
from sklearn.ensemble import VotingClassifier

from smac.tae import StatusType

Expand All @@ -32,6 +32,7 @@
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.evaluation.utils import (
VotingRegressorWrapper,
convert_multioutput_multiclass_to_multilabel
)
from autoPyTorch.pipeline.base_pipeline import BasePipeline
Expand Down Expand Up @@ -513,7 +514,7 @@ def file_output(
if self.task_type in CLASSIFICATION_TASKS:
pipelines = VotingClassifier(estimators=None, voting='soft', )
else:
pipelines = VotingRegressor(estimators=None)
pipelines = VotingRegressorWrapper(estimators=None)
pipelines.estimators_ = self.pipelines
else:
pipelines = None
Expand Down
1 change: 1 addition & 0 deletions autoPyTorch/evaluation/train_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def _predict(self, pipeline: BaseEstimator,
self.y_valid)
else:
valid_pred = None

if self.X_test is not None:
test_pred = self.predict_function(self.X_test, pipeline,
self.y_train[train_indices])
Expand Down