From 969ba0ece31b34c1676da4673424e0815725e60f Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 15 Apr 2021 15:18:43 +0200 Subject: [PATCH 01/14] Working fit_pipeline method, with test and example --- .github/workflows/examples.yml | 1 + autoPyTorch/api/base_task.py | 330 ++++++++++++++---- autoPyTorch/api/tabular_classification.py | 85 +++-- autoPyTorch/api/tabular_regression.py | 82 +++-- autoPyTorch/evaluation/tae.py | 2 +- autoPyTorch/optimizer/smbo.py | 2 +- .../example_single_configuration.py | 83 +++++ test/test_api/test_api.py | 105 +++++- 8 files changed, 564 insertions(+), 126 deletions(-) create mode 100644 examples/tabular/40_advanced/example_single_configuration.py diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 59f70facf..7f8ef50d5 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -34,4 +34,5 @@ jobs: python examples/tabular/20_basics/example_tabular_regression.py python examples/tabular/40_advanced/example_custom_configuration_space.py python examples/tabular/40_advanced/example_resampling_strategy.py + python examples/tabular/40_advanced/example_single_configuration.py python examples/example_image_classification.py \ No newline at end of file diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index c4fa0e7ce..9b9f25a6b 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -13,7 +13,7 @@ import uuid import warnings from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from ConfigSpace.configuration_space import Configuration, ConfigurationSpace @@ -25,7 +25,7 @@ import pandas as pd -from smac.runhistory.runhistory import DataOrigin, RunHistory +from smac.runhistory.runhistory import DataOrigin, RunHistory, RunInfo, RunValue from smac.stats.stats import Stats from smac.tae import StatusType @@ -122,6 +122,15 @@ class BaseTask: exclude_components (Optional[Dict]): If None, all possible components are used. Otherwise specifies set of components not to use. Incompatible with include components + search_space_updates (Optional[HyperparameterSearchSpaceUpdates]): updates to be made + to the hyperparameter search space of the pipeline + resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), + (default=HoldoutValTypes.holdout_validation): + strategy to split the training data. + resampling_strategy_args (Optional[Dict[str, Any]]): arguments + required for the chosen resampling strategy. If None, uses + the default values provided in DEFAULT_RESAMPLING_PARAMETERS + in ```datasets/resampling_strategy.py```. 
""" def __init__( @@ -144,6 +153,26 @@ def __init__( search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None, task_type: Optional[str] = None ) -> None: + """ + + Args: + seed: + n_jobs: + logging_config: + ensemble_size: + ensemble_nbest: + max_models_on_disc: + temporary_directory: + output_directory: + delete_tmp_folder_after_terminate: + delete_output_folder_after_terminate: + include_components: + exclude_components: + backend: + + : + task_type: + """ self.seed = seed self.n_jobs = n_jobs self.ensemble_size = ensemble_size @@ -205,7 +234,11 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An raise NotImplementedError @abstractmethod - def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline: + def build_pipeline(self, dataset_properties: Dict[str, Any], + include_components: Optional[Dict] = None, + exclude_components: Optional[Dict] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + ) -> BasePipeline: """ Build pipeline according to current task and for the passed dataset properties @@ -215,7 +248,23 @@ def build_pipeline(self, dataset_properties: Dict[str, Any]) -> BasePipeline: Returns: """ - raise NotImplementedError + + raise NotImplementedError("Function called on BaseTask, this can only be called by " + "specific task which is a child of the BaseTask") + + @abstractmethod + def _create_dataset(self, + X_train: Union[List, pd.DataFrame, np.ndarray], + y_train: Union[List, pd.DataFrame, np.ndarray], + X_test: Union[List, pd.DataFrame, np.ndarray], + y_test: Union[List, pd.DataFrame, np.ndarray], + resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + dataset_name: Optional[str] = None, + return_only: Optional[bool] = False + ) -> BaseDataset: + raise NotImplementedError("Function called on BaseTask, this can only be called by " + "specific task which is a child of the BaseTask") def set_pipeline_config( self, @@ -396,9 +445,9 @@ def _close_dask_client(self) -> None: None """ if ( - hasattr(self, '_is_dask_client_internally_created') - and self._is_dask_client_internally_created - and self._dask_client + hasattr(self, '_is_dask_client_internally_created') + and self._is_dask_client_internally_created + and self._dask_client ): self._dask_client.shutdown() self._dask_client.close() @@ -420,6 +469,13 @@ def _load_models(self) -> bool: raise ValueError("Resampling strategy is needed to determine what models to load") self.ensemble_ = self._backend.load_ensemble(self.seed) + if isinstance(self._disable_file_output, List): + disabled_file_outputs = self._disable_file_output + disable_file_output = False + elif isinstance(self._disable_file_output, bool): + disable_file_output = self._disable_file_output + disabled_file_outputs = [] + # If no ensemble is loaded, try to get the best performing model if not self.ensemble_: self.ensemble_ = self._load_best_individual_model() @@ -434,7 +490,7 @@ def _load_models(self) -> bool: if len(self.cv_models_) == 0: raise ValueError('No models fitted!') - elif 'pipeline' not in self._disable_file_output: + elif disable_file_output or 'pipeline' not in disabled_file_outputs: model_names = self._backend.list_all_models(self.seed) if len(model_names) == 0: @@ -516,7 +572,7 @@ def _do_dummy_prediction(self) -> None: initial_num_run=num_run, stats=stats, memory_limit=memory_limit, - disable_file_output=True if len(self._disable_file_output) > 0 else 
False, + disable_file_output=self._disable_file_output, all_supported_metrics=self._all_supported_metrics ) @@ -609,7 +665,7 @@ def _do_traditional_prediction(self, time_left: int, func_eval_time_limit_secs: initial_num_run=self._backend.get_next_num_run(), stats=stats, memory_limit=memory_limit, - disable_file_output=True if len(self._disable_file_output) > 0 else False, + disable_file_output=self._disable_file_output, all_supported_metrics=self._all_supported_metrics ) dask_futures.append([ @@ -698,7 +754,7 @@ def _search( get_smac_object_callback: Optional[Callable] = None, all_supported_metrics: bool = True, precision: int = 32, - disable_file_output: List = [], + disable_file_output: Union[bool, List] = False, load_models: bool = True, ) -> 'BaseTask': """ @@ -1008,10 +1064,10 @@ def _search( return self def refit( - self, - dataset: BaseDataset, - budget_config: Dict[str, Union[int, str]] = {}, - split_id: int = 0 + self, + dataset: BaseDataset, + budget_config: Dict[str, Union[int, str]] = {}, + split_id: int = 0 ) -> "BaseTask": """ Refit all models found with fit to new data. @@ -1079,37 +1135,112 @@ def refit( return self - def fit(self, - dataset: BaseDataset, - budget_config: Dict[str, Union[int, str]] = {}, - pipeline_config: Optional[Configuration] = None, - split_id: int = 0) -> BasePipeline: + def fit_pipeline(self, + X_train: Union[List, pd.DataFrame, np.ndarray], + y_train: Union[List, pd.DataFrame, np.ndarray], + X_test: Union[List, pd.DataFrame, np.ndarray], + y_test: Union[List, pd.DataFrame, np.ndarray], + dataset_name: Optional[str] = None, + resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + run_time_limit_secs: int = 60, + memory_limit: Optional[int] = None, + eval_metric: Optional[str] = None, + all_supported_metrics: bool = False, + budget_type: Optional[str] = None, + include_components: Optional[Dict] = None, + exclude_components: Optional[Dict] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None, + budget: float = 50, + configuration: Optional[Configuration] = None, + pipeline_options: Optional[Dict] = None, + disable_file_output: Optional[Union[bool, List]] = False, + return_dataset: bool = True + ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue, Optional[BaseDataset]]: + """ Fit a pipeline on the given task for the budget. A pipeline configuration can be specified if None, uses default + Args: - dataset: (Dataset) - The argument that will provide the dataset splits. It can either - be a dictionary with the splits, or the dataset object which can - generate the splits based on different restrictions. - budget_config: (Optional[Dict[str, Union[int, str]]]) - can contain keys from 'budget_type' and the budget - specified using 'epochs' or 'runtime'. - split_id: (int) (default=0) - split id to fit on. - pipeline_config: (Optional[Configuration]) + X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame] + A pair of features (X_train) and targets (y_train) used to fit a + pipeline. Additionally, a holdout of this pairs (X_test, y_test) can + be provided to track the generalization performance of each stage. + dataset_name (Optional[str]): + Name of the dayaset, if None, random value is used. + resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), + (default=HoldoutValTypes.holdout_validation): + strategy to split the training data. 
+            resampling_strategy_args (Optional[Dict[str, Any]]): arguments
+                required for the chosen resampling strategy. If None, uses
+                the default values provided in DEFAULT_RESAMPLING_PARAMETERS
+                in ```datasets/resampling_strategy.py```.
+            run_time_limit_secs (int), (default=60): Time limit
+                for a single call to the machine learning model.
+                Model fitting will be terminated if the machine
+                learning algorithm runs over the time limit. Set
+                this value high enough so that typical machine
+                learning algorithms can be fit on the training
+                data.
+            memory_limit (Optional[int]), (default=None): Memory
+                limit in MB for the machine learning algorithm. autopytorch
+                will stop fitting the machine learning algorithm if it tries
+                to allocate more than memory_limit MB. If None is provided,
+                no memory limit is set. In case of multi-processing, memory_limit
+                will be per job.
+            eval_metric (Optional[str]): name of the metric that is used to
+                evaluate a pipeline.
+            all_supported_metrics (bool), (default=False): if True, all
+                metrics supporting the current task will be calculated
+                for the pipeline and reported in the returned RunValue
+            budget_type (Optional[str]):
+                Type of budget to be used when fitting the pipeline.
+                Either 'epochs' or 'runtime'. If not provided, uses
+                the default in the pipeline config ('epochs')
+            include_components (Optional[Dict]): If None, all possible components are used.
+                Otherwise specifies set of components to use.
+            exclude_components (Optional[Dict]): If None, all possible components are used.
+                Otherwise specifies set of components not to use. Incompatible with include
+                components
+            search_space_updates (Optional[HyperparameterSearchSpaceUpdates]): updates to be made
+                to the hyperparameter search space of the pipeline
+            budget (float), (default=50):
+                Budget to fit a single run of the pipeline. If not
+                provided, uses the default in the pipeline config
+            pipeline_options (Optional[Dict]):
+                Valid config options include "device",
+                "torch_num_threads", "early_stopping", "use_tensorboard_logger",
+                "metrics_during_training"
+            disable_file_output (Optional[Union[bool, List]]):
+                By default, the model, its predictions and other metadata are stored on disk
+                for each finished configuration. This argument allows the user to skip
+                saving certain file type, for example the model, from being written to disk.
+            configuration: (Optional[Configuration])
+                configuration to fit the pipeline with. If None,
+                uses default.
Returns: (BasePipeline): fitted pipeline + (RunInfo): Run information + (RunValue): Result of fitting the pipeline + (BaseDataset): Dataset created from the given tensors """ - if self.dataset_name is None: - self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) - if self._logger is None: - self._logger = self._get_logger(self.dataset_name) + resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy + resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ + self.resampling_strategy_args + + dataset = self._create_dataset(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + resampling_strategy=resampling_strategy, + resampling_strategy_args=resampling_strategy_args, + dataset_name=dataset_name, + return_only=True) # get dataset properties dataset_requirements = get_dataset_requirements( @@ -1117,35 +1248,110 @@ def fit(self, dataset_properties = dataset.get_dataset_properties(dataset_requirements) self._backend.save_datamanager(dataset) + self._backend._make_internals_directory() + + if self._logger is None: + self._logger = self._get_logger(dataset.dataset_name) + # build pipeline - pipeline = self.build_pipeline(dataset_properties) - if pipeline_config is not None: - pipeline.set_hyperparameters(pipeline_config) + if include_components is None: + include_components = self.include_components + if exclude_components is None: + exclude_components = self.exclude_components + if search_space_updates is None: + search_space_updates = self.search_space_updates + + pipeline = self.build_pipeline(dataset_properties=dataset_properties, + include_components=include_components, + exclude_components=exclude_components, + search_space_updates=search_space_updates) + if configuration is None: + configuration = pipeline.get_hyperparameter_search_space().get_default_configuration() + configuration.__setattr__('config_id', 0) + scenario_mock = unittest.mock.Mock() + scenario_mock.wallclock_limit = run_time_limit_secs + # This stats object is a hack - maybe the SMAC stats object should + # already be generated here! 
+        stats = Stats(scenario_mock)
 
-        # initialise fit dictionary
-        X: Dict[str, Any] = dict({'dataset_properties': dataset_properties,
-                                  'backend': self._backend,
-                                  'X_train': dataset.train_tensors[0],
-                                  'y_train': dataset.train_tensors[1],
-                                  'X_test': dataset.test_tensors[0] if dataset.test_tensors is not None else None,
-                                  'y_test': dataset.test_tensors[1] if dataset.test_tensors is not None else None,
-                                  'train_indices': dataset.splits[split_id][0],
-                                  'val_indices': dataset.splits[split_id][1],
-                                  'split_id': split_id,
-                                  'num_run': self._backend.get_next_num_run(),
-                                  })
-        X.update({**self.pipeline_options, **budget_config})
+        if memory_limit is None:
+            if hasattr(self, '_memory_limit') and self._memory_limit is not None:
+                memory_limit = self._memory_limit
+
+        metric = get_metrics(dataset_properties=dataset_properties,
+                             names=[eval_metric] if eval_metric is not None else None,
+                             all_supported_metrics=False).pop()
+
+        # merge the dicts rather than assigning the result of dict.update(),
+        # which returns None
+        pipeline_options = {**self.pipeline_options, **pipeline_options} if pipeline_options is not None \
+            else self.pipeline_options.copy()
+        if budget_type is not None:
+            assert pipeline_options is not None
+            pipeline_options.update({'budget_type': budget_type})
+        if disable_file_output is None:
+            disable_file_output = self._disable_file_output
+        stats.start_timing()
+        tae = ExecuteTaFuncWithQueue(
+            backend=self._backend,
+            seed=self.seed,
+            metric=metric,
+            logger_port=self._logger_port,
+            cost_for_crash=get_cost_of_crash(metric),
+            abort_on_first_run_crash=False,
+            initial_num_run=self._backend.get_next_num_run(),
+            stats=stats,
+            memory_limit=memory_limit,
+            disable_file_output=disable_file_output,
+            all_supported_metrics=all_supported_metrics,
+            budget_type=budget_type,
+            include=include_components,
+            exclude=exclude_components,
+            search_space_updates=search_space_updates,
+            pipeline_config=pipeline_options
+        )
 
-        fit_and_suppress_warnings(self._logger, pipeline, X, y=None)
+        run_info, run_value = tae.run_wrapper(
+            RunInfo(config=configuration,
+                    budget=budget,
+                    seed=self.seed,
+                    cutoff=run_time_limit_secs,
+                    capped=False,
+                    instance_specific=None,
+                    instance=None)
+        )
+        # as in _load_models: a list names the outputs to skip, while a
+        # bool enables or disables file output altogether
+        disabled_file_outputs: List = []
+        if isinstance(disable_file_output, List):
+            disabled_file_outputs = disable_file_output
+            disable_file_output = False
+        elif not isinstance(disable_file_output, bool):
+            raise ValueError('disable_file_output should be either a bool or a list')
+
+        fitted_pipeline: Optional[BasePipeline] = None
+        if disable_file_output or 'pipeline' in disabled_file_outputs:
+            self._logger.warning("File output is disabled. No pipeline can be returned")
+        elif run_value.status == StatusType.SUCCESS:
+            if self.resampling_strategy in CrossValTypes:
+                load_function = self._backend.load_cv_model_by_seed_and_id_and_budget
+            else:
+                load_function = self._backend.load_model_by_seed_and_id_and_budget
+            fitted_pipeline = load_function(
+                seed=self.seed,
+                idx=run_info.config.config_id + tae.initial_num_run,
+                budget=float(run_info.budget),
+            )
 
         self._clean_logger()
-        return pipeline
+
+        if not return_dataset:
+            dataset = None  # type: ignore [assignment]
+
+        return fitted_pipeline, run_info, run_value, dataset
 
     def predict(
-            self,
-            X_test: np.ndarray,
-            batch_size: Optional[int] = None,
-            n_jobs: int = 1
+        self,
+        X_test: np.ndarray,
+        batch_size: Optional[int] = None,
+        n_jobs: int = 1
     ) -> np.ndarray:
         """Generate the estimator predictions.
         Generate the predictions based on the given examples from the test set.
@@ -1195,9 +1401,9 @@ def predict( return predictions def score( - self, - y_pred: np.ndarray, - y_test: Union[np.ndarray, pd.DataFrame] + self, + y_pred: np.ndarray, + y_test: Union[np.ndarray, pd.DataFrame] ) -> Dict[str, float]: """Calculate the score on the test set. Calculate the evaluation measure on the test set. @@ -1239,13 +1445,13 @@ def __del__(self) -> None: @typing.no_type_check def get_incumbent_results( - self + self ): pass @typing.no_type_check def get_incumbent_config( - self + self ): pass diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index deeb5244b..ca09267c5 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -108,16 +108,60 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An 'numerical_columns': dataset.numerical_columns, 'categorical_columns': dataset.categorical_columns} - def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline: + def build_pipeline(self, dataset_properties: Dict[str, Any], + include_components: Optional[Dict] = None, + exclude_components: Optional[Dict] = None, + search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None + ) -> TabularClassificationPipeline: return TabularClassificationPipeline(dataset_properties=dataset_properties) + def _create_dataset(self, + X_train: Union[List, pd.DataFrame, np.ndarray], + y_train: Union[List, pd.DataFrame, np.ndarray], + X_test: Union[List, pd.DataFrame, np.ndarray], + y_test: Union[List, pd.DataFrame, np.ndarray], + resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + dataset_name: Optional[str] = None, + return_only: Optional[bool] = False + ) -> BaseDataset: + + if dataset_name is None: + dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) + + # Create a validator object to make sure that the data provided by + # the user matches the autopytorch requirements + InputValidator = TabularInputValidator( + is_classification=True, + logger_port=self._logger_port, + ) + + # Fit a input validator to check the provided data + # Also, an encoder is fit to both train and test data, + # to prevent unseen categories during inference + InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test) + + dataset = TabularDataset( + X=X_train, Y=y_train, + X_test=X_test, Y_test=y_test, + validator=InputValidator, + resampling_strategy=self.resampling_strategy, + resampling_strategy_args=self.resampling_strategy_args, + dataset_name=dataset_name + ) + if not return_only: + self.InputValidator = InputValidator + self.dataset = dataset + + return dataset + def search( self, optimize_metric: str, - X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, - y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, - X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, - y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + X_train: Union[List, pd.DataFrame, np.ndarray], + y_train: Union[List, pd.DataFrame, np.ndarray], + X_test: Union[List, pd.DataFrame, np.ndarray], + y_test: Union[List, pd.DataFrame, np.ndarray], dataset_name: Optional[str] = None, budget_type: Optional[str] = None, budget: Optional[float] = None, @@ -143,6 +187,8 @@ def search( A pair of features (X_train) and targets (y_train) used to fit a pipeline. 
Additionally, a holdout of this pairs (X_test, y_test) can
                 be provided to track the generalization performance of each stage.
+            dataset_name (Optional[str]):
+                Name of the dataset, if None, random value is used.
             optimize_metric (str): name of the metric that is used to
                 evaluate a pipeline.
             budget_type (Optional[str]):
@@ -204,31 +250,12 @@ def search(
             self
 
         """
-        if dataset_name is None:
-            dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
-
-        # we have to create a logger for at this point for the validator
-        self._logger = self._get_logger(dataset_name)
-
-        # Create a validator object to make sure that the data provided by
-        # the user matches the autopytorch requirements
-        self.InputValidator = TabularInputValidator(
            is_classification=True,
-            logger_port=self._logger_port,
-        )
-        # Fit a input validator to check the provided data
-        # Also, an encoder is fit to both train and test data,
-        # to prevent unseen categories during inference
-        self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
-
-        self.dataset = TabularDataset(
-            X=X_train, Y=y_train,
-            X_test=X_test, Y_test=y_test,
-            validator=self.InputValidator,
-            resampling_strategy=self.resampling_strategy,
-            resampling_strategy_args=self.resampling_strategy_args,
-        )
+
+        self._create_dataset(X_train=X_train,
+                             y_train=y_train,
+                             X_test=X_test,
+                             y_test=y_test,
+                             dataset_name=dataset_name)
 
         return self._search(
             dataset=self.dataset,
diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py
index afef8ce9f..99c0237ba 100644
--- a/autoPyTorch/api/tabular_regression.py
+++ b/autoPyTorch/api/tabular_regression.py
@@ -100,8 +100,55 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An
             'numerical_columns': dataset.numerical_columns,
             'categorical_columns': dataset.categorical_columns}
 
-    def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularRegressionPipeline:
-        return TabularRegressionPipeline(dataset_properties=dataset_properties)
+    def build_pipeline(self, dataset_properties: Dict[str, Any],
+                       include_components: Optional[Dict] = None,
+                       exclude_components: Optional[Dict] = None,
+                       search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+                       ) -> TabularRegressionPipeline:
+        return TabularRegressionPipeline(dataset_properties=dataset_properties,
+                                         include=include_components,
+                                         exclude=exclude_components,
+                                         search_space_updates=search_space_updates)
+
+    def _create_dataset(self,
+                        X_train: Union[List, pd.DataFrame, np.ndarray],
+                        y_train: Union[List, pd.DataFrame, np.ndarray],
+                        X_test: Union[List, pd.DataFrame, np.ndarray],
+                        y_test: Union[List, pd.DataFrame, np.ndarray],
+                        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                        resampling_strategy_args: Optional[Dict[str, Any]] = None,
+                        dataset_name: Optional[str] = None,
+                        return_only: Optional[bool] = False
+                        ) -> BaseDataset:
+
+        if dataset_name is None:
+            dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
+
+        # Create a validator object to make sure that the data provided by
+        # the user matches the autopytorch requirements
+        InputValidator = TabularInputValidator(
+            is_classification=True,
+            logger_port=self._logger_port,
+        )
+
+        # Fit a input validator to check the provided data
+        # Also, an encoder is fit to both train and test data,
+        # to prevent unseen categories during inference
+        InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
+
+        dataset = TabularDataset(
+            X=X_train,
Y=y_train, + X_test=X_test, Y_test=y_test, + validator=InputValidator, + resampling_strategy=resampling_strategy, + resampling_strategy_args=resampling_strategy_args, + dataset_name=dataset_name + ) + if not return_only: + self.InputValidator = InputValidator + self.dataset = dataset + + return dataset def search( self, @@ -192,31 +239,14 @@ def search( self """ - if dataset_name is None: - dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) - - # we have to create a logger for at this point for the validator - self._logger = self._get_logger(dataset_name) - # Create a validator object to make sure that the data provided by - # the user matches the autopytorch requirements - self.InputValidator = TabularInputValidator( - is_classification=False, - logger_port=self._logger_port, - ) - - # Fit a input validator to check the provided data - # Also, an encoder is fit to both train and test data, - # to prevent unseen categories during inference - self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test) - - self.dataset = TabularDataset( - X=X_train, Y=y_train, - X_test=X_test, Y_test=y_test, - validator=self.InputValidator, - resampling_strategy=self.resampling_strategy, - resampling_strategy_args=self.resampling_strategy_args, - ) + self._create_dataset(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + resampling_strategy=self.resampling_strategy, + resampling_strategy_args=self.resampling_strategy_args, + dataset_name=dataset_name) return self._search( dataset=self.dataset, diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py index 1ef4f552d..c8eafccab 100644 --- a/autoPyTorch/evaluation/tae.py +++ b/autoPyTorch/evaluation/tae.py @@ -105,7 +105,7 @@ def __init__( include: typing.Optional[typing.Dict[str, typing.Any]] = None, exclude: typing.Optional[typing.Dict[str, typing.Any]] = None, memory_limit: typing.Optional[int] = None, - disable_file_output: bool = False, + disable_file_output: typing.Union[bool, typing.List] = False, init_params: typing.Dict[str, typing.Any] = None, budget_type: str = None, ta: typing.Optional[typing.Callable] = None, diff --git a/autoPyTorch/optimizer/smbo.py b/autoPyTorch/optimizer/smbo.py index c00965bbb..1478f83b5 100644 --- a/autoPyTorch/optimizer/smbo.py +++ b/autoPyTorch/optimizer/smbo.py @@ -97,7 +97,7 @@ def __init__(self, resampling_strategy_args: typing.Optional[typing.Dict[str, typing.Any]] = None, include: typing.Optional[typing.Dict[str, typing.Any]] = None, exclude: typing.Optional[typing.Dict[str, typing.Any]] = None, - disable_file_output: typing.List = [], + disable_file_output: typing.Union[bool, typing.List] = [], smac_scenario_args: typing.Optional[typing.Dict[str, typing.Any]] = None, get_smac_object_callback: typing.Optional[typing.Callable] = None, all_supported_metrics: bool = True, diff --git a/examples/tabular/40_advanced/example_single_configuration.py b/examples/tabular/40_advanced/example_single_configuration.py new file mode 100644 index 000000000..4f9068a39 --- /dev/null +++ b/examples/tabular/40_advanced/example_single_configuration.py @@ -0,0 +1,83 @@ +# -*- encoding: utf-8 -*- +""" +========================== +Fit a single configuration +========================== +*Auto-PyTorch* searches for the best combination of machine learning algorithms +and their hyper-parameter configuration for a given task. 
+
This example shows how one can fit one of these pipelines, both with a user-defined
configuration and with a randomly sampled one from the configuration space.
The pipelines that Auto-PyTorch fits are compatible with the Scikit-Learn API. You can
find further documentation about Scikit-Learn models in the Scikit-Learn documentation.
"""
import os
import tempfile as tmp
import warnings

os.environ['JOBLIB_TEMP_FOLDER'] = tmp.gettempdir()
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

warnings.simplefilter(action='ignore', category=UserWarning)
warnings.simplefilter(action='ignore', category=FutureWarning)

import sklearn.datasets
import sklearn.metrics
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes


if __name__ == '__main__':
    ############################################################################
    # Data Loading
    # ============

    X, y = sklearn.datasets.fetch_openml(data_id=3, return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, test_size=0.5, random_state=3
    )

    ############################################################################
    # Define an estimator
    # ============================

    # Create the estimator object; no search is performed in this example
    estimator = TabularClassificationTask(
        resampling_strategy=HoldoutValTypes.holdout_validation,
        resampling_strategy_args={'val_share': 0.33}
    )

    ###########################################################################
    # Fit an user provided configuration
    # ==================================

    # We will create a configuration that has a user defined
    # min_samples_split in the Random Forest. We recommend you to look into
    # how the ConfigSpace package works here:
    # https://automl.github.io/ConfigSpace/master/

    pipeline, run_info, run_value, dataset = estimator.fit_pipeline(X_train=X_train, y_train=y_train,
                                                                    dataset_name='kr-vs-kp',
                                                                    X_test=X_test, y_test=y_test,
                                                                    resampling_strategy=estimator.resampling_strategy,
                                                                    resampling_strategy_args=estimator.
                                                                    resampling_strategy_args,
                                                                    disable_file_output=False,
                                                                    )

    # This object complies with Scikit-Learn Pipeline API.
+ # https://scikit-learn.org/stable/modules/generated/sklearn.pipeline.Pipeline.html + print(pipeline.named_steps) + + # The fit_pipeline command also returns a named tuple with the pipeline constraints + print(run_info) + + # The fit_pipeline command also returns a named tuple with train/test performance + print(run_value) + + # We can make sure that our pipeline configuration was honored as follows + print("Passed Configuration:", pipeline.config) + print("Random Forest:", pipeline.named_steps['network'].choice.network) diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 7866e7674..6f8bed18b 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -1,31 +1,35 @@ import os import pickle import sys +import tempfile import unittest +from ConfigSpace.configuration_space import Configuration + import numpy as np import pandas as pd import pytest - import sklearn import sklearn.datasets -from sklearn.base import clone +from sklearn.base import BaseEstimator, clone from sklearn.ensemble import VotingClassifier, VotingRegressor -from smac.runhistory.runhistory import RunHistory +from smac.runhistory.runhistory import RunHistory, RunInfo, RunValue import torch from autoPyTorch.api.tabular_classification import TabularClassificationTask from autoPyTorch.api.tabular_regression import TabularRegressionTask +from autoPyTorch.datasets.base_dataset import BaseDataset from autoPyTorch.datasets.resampling_strategy import ( CrossValTypes, HoldoutValTypes, ) from autoPyTorch.optimizer.smbo import AutoMLSMBO +from autoPyTorch.pipeline.base_pipeline import BasePipeline from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy @@ -35,12 +39,11 @@ # Test # ======== -@pytest.mark.parametrize('openml_id', (40981, )) +@pytest.mark.parametrize('openml_id', (40981,)) @pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, CrossValTypes.k_fold_cross_validation, )) def test_tabular_classification(openml_id, resampling_strategy, backend): - # Get the data and check that contents of data-manager make sense X, y = sklearn.datasets.fetch_openml( data_id=int(openml_id), @@ -194,12 +197,11 @@ def test_tabular_classification(openml_id, resampling_strategy, backend): restored_estimator.predict(X_test) -@pytest.mark.parametrize('openml_name', ("boston", )) +@pytest.mark.parametrize('openml_name', ("boston",)) @pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, CrossValTypes.k_fold_cross_validation, )) def test_tabular_regression(openml_name, resampling_strategy, backend): - # Get the data and check that contents of data-manager make sense X, y = sklearn.datasets.fetch_openml( openml_name, @@ -449,3 +451,92 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular): estimator._clean_logger() del estimator + + +@pytest.mark.parametrize("disable_file_output", [True, False]) +@pytest.mark.parametrize('openml_id', (40981,)) +@pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, + CrossValTypes.k_fold_cross_validation, + )) +def test_pipeline_fit(openml_id, resampling_strategy, backend, disable_file_output): + # Get the data and check that contents of data-manager make sense + X, y = sklearn.datasets.fetch_openml( + data_id=int(openml_id), + return_X_y=True, as_frame=True + ) + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X, y, random_state=1) + + # Search for a good configuration + estimator = TabularClassificationTask( + backend=backend, 
+ resampling_strategy=resampling_strategy, + ) + + pipeline, run_info, run_value, dataset = estimator.fit_pipeline(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + resampling_strategy=resampling_strategy, + run_time_limit_secs=50, + disable_file_output=disable_file_output + ) + assert isinstance(dataset, BaseDataset) + assert isinstance(run_info, RunInfo) + assert isinstance(run_info.config, Configuration) + + assert isinstance(run_value, RunValue) + assert 'SUCCESS' in str(run_value.status) + + if not disable_file_output: + if resampling_strategy in CrossValTypes: + pytest.skip("Bug, Can't predict with cross validation pipeline") + assert isinstance(pipeline, BaseEstimator) + X_test = dataset.test_tensors[0] + preds = pipeline.predict(X_test) + assert isinstance(preds, np.ndarray) + + score = accuracy(dataset.test_tensors[1], preds) + assert isinstance(score, float) + assert score > 0.8 + else: + assert isinstance(pipeline, BasePipeline) + # To make sure we fitted the model, there should be a + # run summary object with accuracy + run_summary = pipeline.named_steps['trainer'].run_summary + assert run_summary is not None + X_test = dataset.test_tensors[0] + preds = pipeline.predict(X_test) + assert isinstance(preds, np.ndarray) + + score = accuracy(dataset.test_tensors[1], preds) + assert isinstance(score, float) + assert score > 0.8 + else: + assert pipeline is None + assert run_value.cost < 0.2 + + # Make sure that the pipeline can be pickled + dump_file = os.path.join(tempfile.gettempdir(), 'automl.dump.pkl') + with open(dump_file, 'wb') as f: + pickle.dump(pipeline, f) + + num_run_dir = estimator._backend.get_numrun_directory( + run_info.seed, run_value.additional_info['num_run'], budget=50.0) + + cv_model_path = os.path.join(num_run_dir, estimator._backend.get_cv_model_filename( + run_info.seed, run_value.additional_info['num_run'], budget=50.0)) + model_path = os.path.join(num_run_dir, estimator._backend.get_model_filename( + run_info.seed, run_value.additional_info['num_run'], budget=50.0)) + + if disable_file_output: + # No file output is expected + assert not os.path.exists(num_run_dir) + else: + # We expect the model path always + # And the cv model only on 'cv' + assert os.path.exists(model_path) + if resampling_strategy in CrossValTypes: + assert os.path.exists(cv_model_path) + elif resampling_strategy in HoldoutValTypes: + assert not os.path.exists(cv_model_path) From 8cbcf50acd5016e68774b984e5002451fb148c5e Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 15 Apr 2021 17:32:01 +0200 Subject: [PATCH 02/14] Fixed bug in tabular regression --- autoPyTorch/api/tabular_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index 99c0237ba..d9f4930f7 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -127,7 +127,7 @@ def _create_dataset(self, # Create a validator object to make sure that the data provided by # the user matches the autopytorch requirements InputValidator = TabularInputValidator( - is_classification=True, + is_classification=False, logger_port=self._logger_port, ) From c61128d16c55aadb6879c3b573415b18ed72576c Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Thu, 15 Apr 2021 20:35:17 +0200 Subject: [PATCH 03/14] Fix bug in example single configuration --- .../40_advanced/example_single_configuration.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git 
a/examples/tabular/40_advanced/example_single_configuration.py b/examples/tabular/40_advanced/example_single_configuration.py index 4f9068a39..4ad84bca3 100644 --- a/examples/tabular/40_advanced/example_single_configuration.py +++ b/examples/tabular/40_advanced/example_single_configuration.py @@ -51,14 +51,9 @@ ) ########################################################################### - # Fit an user provided configuration + # Fit default configuration # ================================== - # We will create a configuration that has a user defined - # min_samples_split in the Random Forest. We recommend you to look into - # how the ConfigSpace package works here: - # https://automl.github.io/ConfigSpace/master/ - pipeline, run_info, run_value, dataset = estimator.fit_pipeline(X_train=X_train, y_train=y_train, dataset_name='kr-vs-kp', X_test=X_test, y_test=y_test, @@ -78,6 +73,5 @@ # The fit_pipeline command also returns a named tuple with train/test performance print(run_value) - # We can make sure that our pipeline configuration was honored as follows print("Passed Configuration:", pipeline.config) - print("Random Forest:", pipeline.named_steps['network'].choice.network) + print("Network:", pipeline.named_steps['network'].network) From c611d1379dadf0afb7c5265c7b9b9bc5bcf75c32 Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Mon, 19 Apr 2021 14:37:00 +0200 Subject: [PATCH 04/14] Addressed comments from Fransisco, making configuration required in fit_pipeline --- autoPyTorch/api/base_task.py | 85 +++++++++-------------- autoPyTorch/api/tabular_classification.py | 35 +++++----- autoPyTorch/api/tabular_regression.py | 34 ++++----- test/test_api/test_api.py | 36 +++++++--- 4 files changed, 94 insertions(+), 96 deletions(-) diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 9b9f25a6b..18a631f6e 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -153,26 +153,7 @@ def __init__( search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None, task_type: Optional[str] = None ) -> None: - """ - Args: - seed: - n_jobs: - logging_config: - ensemble_size: - ensemble_nbest: - max_models_on_disc: - temporary_directory: - output_directory: - delete_tmp_folder_after_terminate: - delete_output_folder_after_terminate: - include_components: - exclude_components: - backend: - - : - task_type: - """ self.seed = seed self.n_jobs = n_jobs self.ensemble_size = ensemble_size @@ -253,16 +234,16 @@ def build_pipeline(self, dataset_properties: Dict[str, Any], "specific task which is a child of the BaseTask") @abstractmethod - def _create_dataset(self, - X_train: Union[List, pd.DataFrame, np.ndarray], - y_train: Union[List, pd.DataFrame, np.ndarray], - X_test: Union[List, pd.DataFrame, np.ndarray], - y_test: Union[List, pd.DataFrame, np.ndarray], - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: Optional[Dict[str, Any]] = None, - dataset_name: Optional[str] = None, - return_only: Optional[bool] = False - ) -> BaseDataset: + def get_dataset(self, + X_train: Union[List, pd.DataFrame, np.ndarray], + y_train: Union[List, pd.DataFrame, np.ndarray], + X_test: Union[List, pd.DataFrame, np.ndarray], + y_test: Union[List, pd.DataFrame, np.ndarray], + resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + dataset_name: Optional[str] = None, + return_only: Optional[bool] = False + ) -> 
BaseDataset: raise NotImplementedError("Function called on BaseTask, this can only be called by " "specific task which is a child of the BaseTask") @@ -1136,10 +1117,12 @@ def refit( return self def fit_pipeline(self, - X_train: Union[List, pd.DataFrame, np.ndarray], - y_train: Union[List, pd.DataFrame, np.ndarray], - X_test: Union[List, pd.DataFrame, np.ndarray], - y_test: Union[List, pd.DataFrame, np.ndarray], + configuration: Configuration, + dataset: Optional[BaseDataset] = None, + X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, + y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None, dataset_name: Optional[str] = None, resampling_strategy: Optional[Union[HoldoutValTypes, CrossValTypes]] = None, resampling_strategy_args: Optional[Dict[str, Any]] = None, @@ -1152,7 +1135,6 @@ def fit_pipeline(self, exclude_components: Optional[Dict] = None, search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None, budget: float = 50, - configuration: Optional[Configuration] = None, pipeline_options: Optional[Dict] = None, disable_file_output: Optional[Union[bool, List]] = False, return_dataset: bool = True @@ -1169,7 +1151,7 @@ def fit_pipeline(self, pipeline. Additionally, a holdout of this pairs (X_test, y_test) can be provided to track the generalization performance of each stage. dataset_name (Optional[str]): - Name of the dayaset, if None, random value is used. + Name of the dataset, if None, random value is used. resampling_strategy (Union[CrossValTypes, HoldoutValTypes]), (default=HoldoutValTypes.holdout_validation): strategy to split the training data. @@ -1219,8 +1201,7 @@ def fit_pipeline(self, for each finished configuration. This argument allows the user to skip saving certain file type, for example the model, from being written to disk. configuration: (Optional[Configuration]) - configuration to fit the pipeline with. If None, - uses default. + configuration to fit the pipeline with. Returns: (BasePipeline): fitted pipeline @@ -1228,12 +1209,15 @@ def fit_pipeline(self, (RunValue): Result of fitting the pipeline (BaseDataset): Dataset created from the given tensors """ - - resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy - resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ - self.resampling_strategy_args - - dataset = self._create_dataset(X_train=X_train, + if dataset is None: + assert X_train is not None or \ + y_train is not None or \ + X_test is not None or \ + y_test is not None, "No dataset provided, must provide X_train, y_train, X_test, y_test tensors" + resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy + resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ + self.resampling_strategy_args + dataset = self.get_dataset(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test, @@ -1242,14 +1226,18 @@ def fit_pipeline(self, dataset_name=dataset_name, return_only=True) + # TAE expects each configuration to have a config_id. 
+        # For fitting a pipeline as it is not part of the
+        # search process, it makes sense to set it to 0
+        if not hasattr(configuration, 'config_id') or configuration.config_id is None:
+            configuration.__setattr__('config_id', 0)
+
         # get dataset properties
         dataset_requirements = get_dataset_requirements(
             info=self._get_required_dataset_properties(dataset))
         dataset_properties = dataset.get_dataset_properties(dataset_requirements)
         self._backend.save_datamanager(dataset)
-        self._backend._make_internals_directory()
         if self._logger is None:
             self._logger = self._get_logger(dataset.dataset_name)
 
@@ -1261,13 +1249,6 @@ def fit_pipeline(self,
         if search_space_updates is None:
             search_space_updates = self.search_space_updates
 
-        pipeline = self.build_pipeline(dataset_properties=dataset_properties,
-                                       include_components=include_components,
-                                       exclude_components=exclude_components,
-                                       search_space_updates=search_space_updates)
-        if configuration is None:
-            configuration = pipeline.get_hyperparameter_search_space().get_default_configuration()
-        configuration.__setattr__('config_id', 0)
         scenario_mock = unittest.mock.Mock()
         scenario_mock.wallclock_limit = run_time_limit_secs
         # This stats object is a hack - maybe the SMAC stats object should
diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py
index ca09267c5..f15fe26e9 100644
--- a/autoPyTorch/api/tabular_classification.py
+++ b/autoPyTorch/api/tabular_classification.py
@@ -113,18 +113,21 @@ def build_pipeline(self, dataset_properties: Dict[str, Any],
                    exclude_components: Optional[Dict] = None,
                    search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
                    ) -> TabularClassificationPipeline:
-        return TabularClassificationPipeline(dataset_properties=dataset_properties)
+        return TabularClassificationPipeline(dataset_properties=dataset_properties,
+                                             include=include_components,
+                                             exclude=exclude_components,
+                                             search_space_updates=search_space_updates)
 
-    def _create_dataset(self,
-                        X_train: Union[List, pd.DataFrame, np.ndarray],
-                        y_train: Union[List, pd.DataFrame, np.ndarray],
-                        X_test: Union[List, pd.DataFrame, np.ndarray],
-                        y_test: Union[List, pd.DataFrame, np.ndarray],
-                        resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
-                        resampling_strategy_args: Optional[Dict[str, Any]] = None,
-                        dataset_name: Optional[str] = None,
-                        return_only: Optional[bool] = False
-                        ) -> BaseDataset:
+    def get_dataset(self,
+                    X_train: Union[List, pd.DataFrame, np.ndarray],
+                    y_train: Union[List, pd.DataFrame, np.ndarray],
+                    X_test: Union[List, pd.DataFrame, np.ndarray],
+                    y_test: Union[List, pd.DataFrame, np.ndarray],
+                    resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                    resampling_strategy_args: Optional[Dict[str, Any]] = None,
+                    dataset_name: Optional[str] = None,
+                    return_only: Optional[bool] = False
+                    ) -> BaseDataset:
@@ -251,11 +254,11 @@ def search(
 
         """
 
-        self._create_dataset(X_train=X_train,
-                             y_train=y_train,
-                             X_test=X_test,
-                             y_test=y_test,
-                             dataset_name=dataset_name)
+        self.get_dataset(X_train=X_train,
+                         y_train=y_train,
+                         X_test=X_test,
+                         y_test=y_test,
+                         dataset_name=dataset_name)
 
         return self._search(
             dataset=self.dataset,
diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py
index d9f4930f7..e8f750365 100644
--- a/autoPyTorch/api/tabular_regression.py
+++ b/autoPyTorch/api/tabular_regression.py
@@ -110,16 +110,16 @@
def build_pipeline(self, dataset_properties: Dict[str, Any], exclude=exclude_components, search_space_updates=search_space_updates) - def _create_dataset(self, - X_train: Union[List, pd.DataFrame, np.ndarray], - y_train: Union[List, pd.DataFrame, np.ndarray], - X_test: Union[List, pd.DataFrame, np.ndarray], - y_test: Union[List, pd.DataFrame, np.ndarray], - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, - resampling_strategy_args: Optional[Dict[str, Any]] = None, - dataset_name: Optional[str] = None, - return_only: Optional[bool] = False - ) -> BaseDataset: + def get_dataset(self, + X_train: Union[List, pd.DataFrame, np.ndarray], + y_train: Union[List, pd.DataFrame, np.ndarray], + X_test: Union[List, pd.DataFrame, np.ndarray], + y_test: Union[List, pd.DataFrame, np.ndarray], + resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, + resampling_strategy_args: Optional[Dict[str, Any]] = None, + dataset_name: Optional[str] = None, + return_only: Optional[bool] = False + ) -> BaseDataset: if dataset_name is None: dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) @@ -240,13 +240,13 @@ def search( """ - self._create_dataset(X_train=X_train, - y_train=y_train, - X_test=X_test, - y_test=y_test, - resampling_strategy=self.resampling_strategy, - resampling_strategy_args=self.resampling_strategy_args, - dataset_name=dataset_name) + self.get_dataset(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + resampling_strategy=self.resampling_strategy, + resampling_strategy_args=self.resampling_strategy_args, + dataset_name=dataset_name) return self._search( dataset=self.dataset, diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 6f8bed18b..53f270d8d 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -31,7 +31,7 @@ from autoPyTorch.optimizer.smbo import AutoMLSMBO from autoPyTorch.pipeline.base_pipeline import BasePipeline from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy - +from autoPyTorch.utils.pipeline import get_dataset_requirements # Fixtures # ======== @@ -454,11 +454,17 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular): @pytest.mark.parametrize("disable_file_output", [True, False]) -@pytest.mark.parametrize('openml_id', (40981,)) -@pytest.mark.parametrize('resampling_strategy', (HoldoutValTypes.holdout_validation, - CrossValTypes.k_fold_cross_validation, - )) -def test_pipeline_fit(openml_id, resampling_strategy, backend, disable_file_output): +@pytest.mark.parametrize('openml_id', (40984,)) +@pytest.mark.parametrize('resampling_strategy,resampling_strategy_args', + ((HoldoutValTypes.holdout_validation, {'val_share': 0.8}), + (CrossValTypes.k_fold_cross_validation, {'num_splits': 2}) + ) + ) +def test_pipeline_fit(openml_id, + resampling_strategy, + resampling_strategy_args, + backend, + disable_file_output): # Get the data and check that contents of data-manager make sense X, y = sklearn.datasets.fetch_openml( data_id=int(openml_id), @@ -473,11 +479,19 @@ def test_pipeline_fit(openml_id, resampling_strategy, backend, disable_file_outp resampling_strategy=resampling_strategy, ) - pipeline, run_info, run_value, dataset = estimator.fit_pipeline(X_train=X_train, - y_train=y_train, - X_test=X_test, - y_test=y_test, - resampling_strategy=resampling_strategy, + dataset = estimator.get_dataset(X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + 
resampling_strategy=resampling_strategy, + resampling_strategy_args=resampling_strategy_args) + dataset_requirements = get_dataset_requirements( + info=dataset.get_required_dataset_info()) + dataset_properties = dataset.get_dataset_properties(dataset_requirements) + configuration = estimator.build_pipeline(dataset_properties).\ + get_hyperparameter_search_space().get_default_configuration() + pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset, + configuration=configuration, run_time_limit_secs=50, disable_file_output=disable_file_output ) From 8ae47599bb3c7b068364f4b52f511bc0f834402d Mon Sep 17 00:00:00 2001 From: Ravin Kohli Date: Mon, 19 Apr 2021 21:46:27 +0200 Subject: [PATCH 05/14] Add configuration to example, fix in get_dataset --- autoPyTorch/api/base_task.py | 3 --- autoPyTorch/api/tabular_classification.py | 10 ++++++--- autoPyTorch/api/tabular_regression.py | 6 +++++- .../example_single_configuration.py | 21 +++++++++++++++---- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 18a631f6e..05f942929 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -1214,9 +1214,6 @@ def fit_pipeline(self, y_train is not None or \ X_test is not None or \ y_test is not None, "No dataset provided, must provide X_train, y_train, X_test, y_test tensors" - resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy - resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ - self.resampling_strategy_args dataset = self.get_dataset(X_train=X_train, y_train=y_train, X_test=X_test, diff --git a/autoPyTorch/api/tabular_classification.py b/autoPyTorch/api/tabular_classification.py index f15fe26e9..71bb99729 100644 --- a/autoPyTorch/api/tabular_classification.py +++ b/autoPyTorch/api/tabular_classification.py @@ -123,7 +123,7 @@ def get_dataset(self, y_train: Union[List, pd.DataFrame, np.ndarray], X_test: Union[List, pd.DataFrame, np.ndarray], y_test: Union[List, pd.DataFrame, np.ndarray], - resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation, + resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None, resampling_strategy_args: Optional[Dict[str, Any]] = None, dataset_name: Optional[str] = None, return_only: Optional[bool] = False @@ -132,6 +132,10 @@ def get_dataset(self, if dataset_name is None: dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) + resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy + resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \ + self.resampling_strategy_args + # Create a validator object to make sure that the data provided by # the user matches the autopytorch requirements InputValidator = TabularInputValidator( @@ -148,8 +152,8 @@ def get_dataset(self, X=X_train, Y=y_train, X_test=X_test, Y_test=y_test, validator=InputValidator, - resampling_strategy=self.resampling_strategy, - resampling_strategy_args=self.resampling_strategy_args, + resampling_strategy=resampling_strategy, + resampling_strategy_args=resampling_strategy_args, dataset_name=dataset_name ) if not return_only: diff --git a/autoPyTorch/api/tabular_regression.py b/autoPyTorch/api/tabular_regression.py index e8f750365..dc867c21a 100644 --- a/autoPyTorch/api/tabular_regression.py +++ b/autoPyTorch/api/tabular_regression.py @@ -115,7 
+115,7 @@ def get_dataset(self,
                     y_train: Union[List, pd.DataFrame, np.ndarray],
                     X_test: Union[List, pd.DataFrame, np.ndarray],
                     y_test: Union[List, pd.DataFrame, np.ndarray],
-                    resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                    resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
                     resampling_strategy_args: Optional[Dict[str, Any]] = None,
                     dataset_name: Optional[str] = None,
                     return_only: Optional[bool] = False
@@ -124,6 +124,10 @@ def get_dataset(self,
         if dataset_name is None:
             dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
 
+        resampling_strategy = resampling_strategy if resampling_strategy is not None else self.resampling_strategy
+        resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is not None else \
+            self.resampling_strategy_args
+
         # Create a validator object to make sure that the data provided by
         # the user matches the autopytorch requirements
         InputValidator = TabularInputValidator(
diff --git a/examples/tabular/40_advanced/example_single_configuration.py b/examples/tabular/40_advanced/example_single_configuration.py
index 4ad84bca3..a354b36cc 100644
--- a/examples/tabular/40_advanced/example_single_configuration.py
+++ b/examples/tabular/40_advanced/example_single_configuration.py
@@ -28,6 +28,7 @@
 
 from autoPyTorch.api.tabular_classification import TabularClassificationTask
 from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes
+from autoPyTorch.utils.pipeline import get_dataset_requirements
 
 
 if __name__ == '__main__':
@@ -50,17 +51,29 @@
         resampling_strategy_args={'val_share': 0.33}
     )
 
+    ############################################################################
+    # Get the default configuration of the pipeline for the current dataset
+    # =======================================================================
+
+    dataset = estimator.get_dataset(X_train=X_train,
+                                    y_train=y_train,
+                                    X_test=X_test,
+                                    y_test=y_test)
+    dataset_requirements = get_dataset_requirements(
+        info=dataset.get_required_dataset_info())
+    dataset_properties = dataset.get_dataset_properties(dataset_requirements)
+    configuration = estimator.build_pipeline(dataset_properties).\
+        get_hyperparameter_search_space().get_default_configuration()
+
     ###########################################################################
-    # Fit default configuration
+    # Fit the configuration
     # ==================================
 
     pipeline, run_info, run_value, dataset = estimator.fit_pipeline(X_train=X_train, y_train=y_train,
                                                                     dataset_name='kr-vs-kp',
                                                                     X_test=X_test, y_test=y_test,
-                                                                    resampling_strategy=estimator.resampling_strategy,
-                                                                    resampling_strategy_args=estimator.
-                                                                    resampling_strategy_args,
                                                                     disable_file_output=False,
+                                                                    configuration=configuration
                                                                     )
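
Taken together, patches 01-05 leave the single-configuration workflow in the state
sketched below. This is a minimal, illustrative sketch and not part of the patch
series itself: it assumes the `get_dataset` and `fit_pipeline` signatures introduced
above, and the OpenML id, time limit and printed fields are arbitrary placeholders.

    import sklearn.datasets
    import sklearn.model_selection

    from autoPyTorch.api.tabular_classification import TabularClassificationTask
    from autoPyTorch.utils.pipeline import get_dataset_requirements

    X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
        X, y, random_state=1)

    estimator = TabularClassificationTask()

    # Build the dataset once so the same splits are used both to derive the
    # search space and to fit the pipeline.
    dataset = estimator.get_dataset(X_train=X_train, y_train=y_train,
                                    X_test=X_test, y_test=y_test)

    # As of PATCH 04, `configuration` is a required argument of fit_pipeline;
    # here we simply take the default of the pipeline's search space.
    dataset_requirements = get_dataset_requirements(
        info=dataset.get_required_dataset_info())
    dataset_properties = dataset.get_dataset_properties(dataset_requirements)
    configuration = (estimator.build_pipeline(dataset_properties)
                     .get_hyperparameter_search_space()
                     .get_default_configuration())

    pipeline, run_info, run_value, dataset = estimator.fit_pipeline(
        dataset=dataset,
        configuration=configuration,
        run_time_limit_secs=60,
    )

    # `pipeline` is None when file output is disabled or the run failed.
    if pipeline is not None:
        print(pipeline.named_steps)
    print(run_value.status)

Note that PATCH 11/14 below shortens the configuration retrieval to
`estimator.get_search_space(dataset).get_default_configuration()`; the rest of the
flow is unchanged.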
From e41874225542aa86f56ca30a1cece193ce2d029a Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 19 Apr 2021 21:51:38 +0200
Subject: [PATCH 06/14] Fix mypy

---
 autoPyTorch/api/base_task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index 05f942929..5ab5b6382 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -239,7 +239,7 @@ def get_dataset(self,
                     y_train: Union[List, pd.DataFrame, np.ndarray],
                     X_test: Union[List, pd.DataFrame, np.ndarray],
                     y_test: Union[List, pd.DataFrame, np.ndarray],
-                    resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+                    resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]] = None,
                     resampling_strategy_args: Optional[Dict[str, Any]] = None,
                     dataset_name: Optional[str] = None,
                     return_only: Optional[bool] = False

From 29547ef88a16de7bd2c1e9e1ad6c2b0844664d92 Mon Sep 17 00:00:00 2001
From: chico
Date: Mon, 19 Apr 2021 21:29:57 +0200
Subject: [PATCH 07/14] [FIX] hardcoded budget

---
 autoPyTorch/evaluation/tae.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py
index c8eafccab..8e9885a0d 100644
--- a/autoPyTorch/evaluation/tae.py
+++ b/autoPyTorch/evaluation/tae.py
@@ -199,7 +199,7 @@ def run_wrapper(
             )
         else:
             if run_info.budget == 0:
-                run_info = run_info._replace(budget=100.0)
+                run_info = run_info._replace(budget=self.pipeline_config[self.budget_type])
             elif run_info.budget <= 0 or run_info.budget > 100:
                 raise ValueError('Illegal value for budget, must be >0 and <=100, but is %f' %
                                  run_info.budget)

From 058c2b59e9ba66644dc89c1ab925f8adc3912c86 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Mon, 19 Apr 2021 22:18:28 +0200
Subject: [PATCH 08/14] fix mypy for pipeline config

---
 autoPyTorch/evaluation/tae.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/autoPyTorch/evaluation/tae.py b/autoPyTorch/evaluation/tae.py
index 8e9885a0d..0ed10c6c6 100644
--- a/autoPyTorch/evaluation/tae.py
+++ b/autoPyTorch/evaluation/tae.py
@@ -4,6 +4,7 @@
 import logging
 import math
 import multiprocessing
+import os
 import time
 import traceback
 import typing
@@ -25,6 +26,7 @@
 from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue
 from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
 from autoPyTorch.utils.backend import Backend
+from autoPyTorch.utils.common import replace_string_bool_to_bool
 from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
 from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger

@@ -144,7 +146,12 @@ def __init__(
         self.exclude = exclude
         self.disable_file_output = disable_file_output
         self.init_params = init_params
-        self.pipeline_config = pipeline_config
+        self.pipeline_config: typing.Dict[str, typing.Union[int, str, float]] = dict()
+        if pipeline_config is None:
+            pipeline_config = replace_string_bool_to_bool(json.load(open(
+                os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))
+        self.pipeline_config.update(pipeline_config)
+        self.budget_type = pipeline_config['budget_type'] if pipeline_config is not None else budget_type
         self.logger_port = logger_port
         if self.logger_port is None:
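
Behaviourally, PATCH 07 and PATCH 08 together change how a budget of 0 is resolved. The standalone sketch below mirrors that logic; the DEFAULT_PIPELINE_OPTIONS values and the resolve_budget helper are assumptions standing in for autoPyTorch/configs/default_pipeline_options.json and the run_wrapper code, and the rename to '1_1_50.0' run directories in PATCH 10 below suggests a default budget of 50 epochs.

import typing

# Assumed stand-in for autoPyTorch/configs/default_pipeline_options.json.
DEFAULT_PIPELINE_OPTIONS: typing.Dict[str, typing.Union[int, str, float]] = {
    'budget_type': 'epochs',
    'epochs': 50,
}


def resolve_budget(requested_budget: float,
                   pipeline_config: typing.Optional[typing.Dict[str, typing.Union[int, str, float]]] = None
                   ) -> float:
    """Hypothetical mirror of the run_wrapper budget handling after PATCH 07/08."""
    if pipeline_config is None:
        # PATCH 08: fall back to the packaged defaults instead of keeping None.
        pipeline_config = dict(DEFAULT_PIPELINE_OPTIONS)
    budget_type = str(pipeline_config['budget_type'])
    if requested_budget == 0:
        # PATCH 07: a zero budget resolves to the configured default
        # (50 epochs under the assumed options) rather than a hardcoded 100.0.
        return float(pipeline_config[budget_type])
    if requested_budget < 0 or requested_budget > 100:
        raise ValueError('Illegal value for budget, must be >0 and <=100, '
                         'but is %f' % requested_budget)
    return requested_budget


assert resolve_budget(0) == 50.0
assert resolve_budget(30.0) == 30.0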
From f6ca6e0c26c105e712e8467ff8556f92ef225a7a Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Tue, 20 Apr 2021 10:12:28 +0200
Subject: [PATCH 09/14] Address Arlind's comment for task type documentation

---
 autoPyTorch/api/base_task.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py
index 5ab5b6382..b7616e865 100644
--- a/autoPyTorch/api/base_task.py
+++ b/autoPyTorch/api/base_task.py
@@ -131,6 +131,8 @@ class BaseTask:
             required for the chosen resampling strategy. If None, uses
             the default values provided in DEFAULT_RESAMPLING_PARAMETERS
             in ```datasets/resampling_strategy.py```.
+        task_type (str): The task of the experiment as a string. Currently, supported
+            tasks are 'tabular_classification' and 'tabular_regression'
     """

     def __init__(

From 5f7adfdce10ad90bd83d67b252e673649842d9a9 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Tue, 20 Apr 2021 15:58:05 +0200
Subject: [PATCH 10/14] Fix bug in tests

---
 test/test_api/test_api.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index a0c1dd368..20ff29bd6 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -435,14 +435,14 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular):
     # directory, but in the temporary directory.
     assert not os.path.exists(os.path.join(os.getcwd(), '.autoPyTorch'))
     assert os.path.exists(os.path.join(
-        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_1.0',
-        'predictions_ensemble_1_1_1.0.npy')
+        backend.temporary_directory, '.autoPyTorch', 'runs', '1_1_50.0',
+        'predictions_ensemble_1_1_50.0.npy')
     )

     model_path = os.path.join(backend.temporary_directory,
                               '.autoPyTorch',
-                              'runs', '1_1_1.0',
-                              '1.1.1.0.model')
+                              'runs', '1_1_50.0',
+                              '1.1.50.0.model')

     # Make sure the dummy model complies with scikit learn
     # get/set params

From cfd728c7a0145661d55b47574c10422972154ccc Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Wed, 21 Apr 2021 18:56:39 +0200
Subject: [PATCH 11/14] Change way to get configuration for using fit_pipeline

---
 .../tabular/40_advanced/example_single_configuration.py | 6 +-----
 test/test_api/test_api.py                               | 7 ++-----
 2 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/examples/tabular/40_advanced/example_single_configuration.py b/examples/tabular/40_advanced/example_single_configuration.py
index a354b36cc..42f92df06 100644
--- a/examples/tabular/40_advanced/example_single_configuration.py
+++ b/examples/tabular/40_advanced/example_single_configuration.py
@@ -59,11 +59,7 @@
                                     y_train=y_train,
                                     X_test=X_test,
                                     y_test=y_test)
-    dataset_requirements = get_dataset_requirements(
-        info=dataset.get_required_dataset_info())
-    dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-    configuration = estimator.build_pipeline(dataset_properties).\
-        get_hyperparameter_search_space().get_default_configuration()
+    configuration = estimator.get_search_space(dataset).get_default_configuration()

     ###########################################################################
     # Fit the configuration
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index 20ff29bd6..be4a88340 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -488,11 +488,8 @@ def test_pipeline_fit(openml_id,
                           y_test=y_test,
                           resampling_strategy=resampling_strategy,
                           resampling_strategy_args=resampling_strategy_args)
-    dataset_requirements = get_dataset_requirements(
-        info=dataset.get_required_dataset_info())
-    dataset_properties = dataset.get_dataset_properties(dataset_requirements)
-    configuration = estimator.build_pipeline(dataset_properties).\
-        get_hyperparameter_search_space().get_default_configuration()
+
+    configuration = estimator.get_search_space(dataset).get_default_configuration()
     pipeline, run_info, run_value, dataset = estimator.fit_pipeline(dataset=dataset,
                                                                     configuration=configuration,
                                                                     run_time_limit_secs=50,
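
The net effect of PATCH 11, shown side by side (a sketch with `estimator` and `dataset` as constructed in the example above): the manual dataset-requirements plumbing collapses into a single call, which is also why PATCH 12 below can drop the now-unused get_dataset_requirements imports.

# Before: build the whole pipeline just to reach its search space.
dataset_requirements = get_dataset_requirements(
    info=dataset.get_required_dataset_info())
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
configuration = estimator.build_pipeline(dataset_properties).\
    get_hyperparameter_search_space().get_default_configuration()

# After: a single call on the estimator yields the same default configuration.
configuration = estimator.get_search_space(dataset).get_default_configuration()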
From 7f6cddc9f35390b40408058b886d65e2176de376 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Wed, 21 Apr 2021 19:07:05 +0200
Subject: [PATCH 12/14] fix flake

---
 examples/tabular/40_advanced/example_single_configuration.py | 1 -
 test/test_api/test_api.py                                    | 1 -
 2 files changed, 2 deletions(-)

diff --git a/examples/tabular/40_advanced/example_single_configuration.py b/examples/tabular/40_advanced/example_single_configuration.py
index 42f92df06..f9aa27278 100644
--- a/examples/tabular/40_advanced/example_single_configuration.py
+++ b/examples/tabular/40_advanced/example_single_configuration.py
@@ -28,7 +28,6 @@

 from autoPyTorch.api.tabular_classification import TabularClassificationTask
 from autoPyTorch.datasets.resampling_strategy import HoldoutValTypes
-from autoPyTorch.utils.pipeline import get_dataset_requirements


 if __name__ == '__main__':
diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py
index be4a88340..20633db1e 100644
--- a/test/test_api/test_api.py
+++ b/test/test_api/test_api.py
@@ -33,7 +33,6 @@
 from autoPyTorch.optimizer.smbo import AutoMLSMBO
 from autoPyTorch.pipeline.base_pipeline import BasePipeline
 from autoPyTorch.pipeline.components.training.metrics.metrics import accuracy
-from autoPyTorch.utils.pipeline import get_dataset_requirements

 # Fixtures
 # ========

From 9cf285c0ec908e633a94c27ae78f8d1f22baba23 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Wed, 21 Apr 2021 20:28:47 +0200
Subject: [PATCH 13/14] Trial with --forked

---
 .github/workflows/pytest.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 2084d7138..e4b226d86 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -29,7 +29,7 @@ jobs:
     - name: Run tests
      run: |
        if [ ${{ matrix.code-cov }} ]; then codecov='--cov=autoPyTorch --cov-report=xml'; fi
-        python -m pytest --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
+        python -m pytest --forked --durations=20 --timeout=600 --timeout-method=signal -v $codecov test
    - name: Check for files left behind by test
      if: ${{ always() }}
      run: |

From 08c6ef02cb5ac019680256fb52db5f606eaf0853 Mon Sep 17 00:00:00 2001
From: Ravin Kohli
Date: Wed, 21 Apr 2021 20:30:42 +0200
Subject: [PATCH 14/14] update setup.py

---
 setup.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 30a9a0697..a3055b41c 100755
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,8 @@
         "codecov",
         "pep8",
         "mypy",
-        "openml"
+        "openml",
+        "pytest-forked"
     ],
     "examples": [
         "matplotlib",