Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,5 @@ jobs:
python examples/tabular/20_basics/example_tabular_regression.py
python examples/tabular/40_advanced/example_custom_configuration_space.py
python examples/tabular/40_advanced/example_resampling_strategy.py
python examples/tabular/40_advanced/example_single_configuration.py
python examples/example_image_classification.py
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Run tests
run: |
if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
python -m pytest --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
python -m pytest --forked --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
- name: Check for files left behind by test
if: ${{ always() }}
run: |
Expand Down
314 changes: 250 additions & 64 deletions autoPyTorch/api/base_task.py

Large diffs are not rendered by default.

94 changes: 64 additions & 30 deletions autoPyTorch/api/tabular_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,67 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An
'numerical_columns': dataset.numerical_columns,
'categorical_columns': dataset.categorical_columns}

def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline:
return TabularClassificationPipeline(dataset_properties=dataset_properties)
def build_pipeline(self, dataset_properties: Dict[str, Any],
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
) -> TabularClassificationPipeline:
return TabularClassificationPipeline(dataset_properties=dataset_properties,
include=include_components,
exclude=exclude_components,
search_space_updates=search_space_updates)

def get_dataset(self,
X_train: Union[List, pd.DataFrame, np.ndarray],
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Union[List, pd.DataFrame, np.ndarray],
y_test: Union[List, pd.DataFrame, np.ndarray],
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
return_only: Optional[bool] = False
) -> BaseDataset:

if dataset_name is None:
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy
resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
self.resampling_strategy_args

# Create a validator object to make sure that the data provided by
# the user matches the autopytorch requirements
InputValidator = TabularInputValidator(
is_classification=True,
logger_port=self._logger_port,
)

# Fit a input validator to check the provided data
# Also, an encoder is fit to both train and test data,
# to prevent unseen categories during inference
InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

dataset = TabularDataset(
X=X_train, Y=y_train,
X_test=X_test, Y_test=y_test,
validator=InputValidator,
resampling_strategy=resampling_strategy,
resampling_strategy_args=resampling_strategy_args,
dataset_name=dataset_name
)
if not return_only:
self.InputValidator = InputValidator
self.dataset = dataset

return dataset

def search(
self,
optimize_metric: str,
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
X_train: Union[List, pd.DataFrame, np.ndarray],
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Union[List, pd.DataFrame, np.ndarray],
y_test: Union[List, pd.DataFrame, np.ndarray],
dataset_name: Optional[str] = None,
budget_type: Optional[str] = None,
budget: Optional[float] = None,
Expand All @@ -143,6 +194,8 @@ def search(
A pair of features (X_train) and targets (y_train) used to fit a
pipeline. Additionally, a holdout of this pairs (X_test, y_test) can
be provided to track the generalization performance of each stage.
dataset_name (Optional[str]):
Name of the dayaset, if None, random value is used
optimize_metric (str): name of the metric that is used to
evaluate a pipeline.
budget_type (Optional[str]):
Expand Down Expand Up @@ -204,31 +257,12 @@ def search(
self

"""
if dataset_name is None:
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

# we have to create a logger for at this point for the validator
self._logger = self._get_logger(dataset_name)

# Create a validator object to make sure that the data provided by
# the user matches the autopytorch requirements
self.InputValidator = TabularInputValidator(
is_classification=True,
logger_port=self._logger_port,
)

# Fit a input validator to check the provided data
# Also, an encoder is fit to both train and test data,
# to prevent unseen categories during inference
self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

self.dataset = TabularDataset(
X=X_train, Y=y_train,
X_test=X_test, Y_test=y_test,
validator=self.InputValidator,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
)
self.get_dataset(X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
dataset_name=dataset_name)

return self._search(
dataset=self.dataset,
Expand Down
86 changes: 60 additions & 26 deletions autoPyTorch/api/tabular_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,59 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An
'numerical_columns': dataset.numerical_columns,
'categorical_columns': dataset.categorical_columns}

def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
return TabularRegressionPipeline(dataset_properties=dataset_properties)
def build_pipeline(self, dataset_properties: Dict[str, Any],
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
) -> TabularRegressionPipeline:
return TabularRegressionPipeline(dataset_properties=dataset_properties,
include=include_components,
exclude=exclude_components,
search_space_updates=search_space_updates)

def get_dataset(self,
X_train: Union[List, pd.DataFrame, np.ndarray],
y_train: Union[List, pd.DataFrame, np.ndarray],
X_test: Union[List, pd.DataFrame, np.ndarray],
y_test: Union[List, pd.DataFrame, np.ndarray],
resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
dataset_name: Optional[str] = None,
return_only: Optional[bool] = False
) -> BaseDataset:

if dataset_name is None:
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy
resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
self.resampling_strategy_args

# Create a validator object to make sure that the data provided by
# the user matches the autopytorch requirements
InputValidator = TabularInputValidator(
is_classification=False,
logger_port=self._logger_port,
)

# Fit a input validator to check the provided data
# Also, an encoder is fit to both train and test data,
# to prevent unseen categories during inference
InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

dataset = TabularDataset(
X=X_train, Y=y_train,
X_test=X_test, Y_test=y_test,
validator=InputValidator,
resampling_strategy=resampling_strategy,
resampling_strategy_args=resampling_strategy_args,
dataset_name=dataset_name
)
if not return_only:
self.InputValidator = InputValidator
self.dataset = dataset

return dataset

def search(
self,
Expand Down Expand Up @@ -192,31 +243,14 @@ def search(
self

"""
if dataset_name is None:
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))

# we have to create a logger for at this point for the validator
self._logger = self._get_logger(dataset_name)

# Create a validator object to make sure that the data provided by
# the user matches the autopytorch requirements
self.InputValidator = TabularInputValidator(
is_classification=False,
logger_port=self._logger_port,
)

# Fit a input validator to check the provided data
# Also, an encoder is fit to both train and test data,
# to prevent unseen categories during inference
self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

self.dataset = TabularDataset(
X=X_train, Y=y_train,
X_test=X_test, Y_test=y_test,
validator=self.InputValidator,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
)
self.get_dataset(X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
dataset_name=dataset_name)

return self._search(
dataset=self.dataset,
Expand Down
13 changes: 10 additions & 3 deletions autoPyTorch/evaluation/tae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import math
import multiprocessing
import os
import time
import traceback
import typing
Expand All @@ -25,6 +26,7 @@
from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.utils.backend import Backend
from autoPyTorch.utils.common import replace_string_bool_to_bool
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger

Expand Down Expand Up @@ -105,7 +107,7 @@ def __init__(
include: typing.Optional[typing.Dict[str, typing.Any]] = None,
exclude: typing.Optional[typing.Dict[str, typing.Any]] = None,
memory_limit: typing.Optional[int] = None,
disable_file_output: bool = False,
disable_file_output: typing.Union[bool, typing.List] = False,
init_params: typing.Dict[str, typing.Any] = None,
budget_type: str = None,
ta: typing.Optional[typing.Callable] = None,
Expand Down Expand Up @@ -144,7 +146,12 @@ def __init__(
self.exclude = exclude
self.disable_file_output = disable_file_output
self.init_params = init_params
self.pipeline_config = pipeline_config
self.pipeline_config: typing.Dict[str, typing.Union[int, str, float]] = dict()
if pipeline_config is None:
pipeline_config = replace_string_bool_to_bool(json.load(open(
os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
self.pipeline_config.update(pipeline_config)

self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
self.logger_port = logger_port
if self.logger_port is None:
Expand Down Expand Up @@ -199,7 +206,7 @@ def run_wrapper(
)
else:
if run_info.budget == 0:
run_info = run_info._replace(budget=100.0)
run_info = run_info._replace(budget=self.pipeline_config[self.budget_type])
elif run_info.budget <= 0 or run_info.budget > 100:
raise ValueError('Illegal value for budget, must be >0 and <=100, but is %f' %
run_info.budget)
Expand Down
2 changes: 1 addition & 1 deletion autoPyTorch/optimizer/smbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(self,
resampling_strategy_args: typing.Optional[typing.Dict[str, typing.Any]] = None,
include: typing.Optional[typing.Dict[str, typing.Any]] = None,
exclude: typing.Optional[typing.Dict[str, typing.Any]] = None,
disable_file_output: typing.List = [],
disable_file_output: typing.Union[bool, typing.List] = [],
smac_scenario_args: typing.Optional[typing.Dict[str, typing.Any]] = None,
get_smac_object_callback: typing.Optional[typing.Callable] = None,
all_supported_metrics: bool = True,
Expand Down
85 changes: 85 additions & 0 deletions examples/tabular/40_advanced/example_single_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# -*- encoding: utf-8 -*-
"""
==========================
Fit a single configuration
==========================
*Auto-PyTorch* searches for the best combination of machine learning algorithms
and their hyper-parameter configuration for a given task.

This example shows how one can fit one of these pipelines, both, with a user defined
configuration, and a randomly sampled one form the configuration space.
The pipelines that Auto-PyTorch fits are compatible with Scikit-Learn API. You can
get further documentation about Scikit-Learn models here: <https://scikit-learn.org/stable/getting_started.html`>_
"""
import os
import tempfile as tmp
import warnings

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

import sklearn.datasets
import sklearn.metrics

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes


if __name__ == '__main__':
############################################################################
# Data Loading
# ============

X, y = sklearn.datasets.fetch_openml(data_id=3, return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
X, y, test_size=0.5, random_state=3
)

############################################################################
# Define an estimator
# ============================

# Search for a good configuration
estimator = TabularClassificationTask(
resampling_strategy=HoldoutValTypes.holdout_validation,
resampling_strategy_args={'val_share': 0.33}
)

############################################################################
# Get a random configuration of the pipeline for current dataset
# ===============================================================

dataset = estimator.get_dataset(X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test)
configuration = estimator.get_search_space(dataset).get_default_configuration()

###########################################################################
# Fit the configuration
# ==================================

pipeline, run_info, run_value, dataset = estimator.fit_pipeline(X_train=X_train, y_train=y_train,
dataset_name='kr-vs-kp',
X_test=X_test, y_test=y_test,
disable_file_output=False,
configuration=configuration
)

# This object complies with Scikit-Learn Pipeline API.
# https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html
print(pipeline.named_steps)

# The fit_pipeline command also returns a named tuple with the pipeline constraints
print(run_info)

# The fit_pipeline command also returns a named tuple with train/test performance
print(run_value)

print("Passed Configuration:", pipeline.config)
print("Network:", pipeline.named_steps['network'].network)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
"codecov",
"pep8",
"mypy",
"openml"
"openml",
"pytest-forked"
],
"examples": [
"matplotlib",
Expand Down
Loading