diff --git a/autoPyTorch/api/base_task.py b/autoPyTorch/api/base_task.py index 9e59601a8..320f9a8d6 100644 --- a/autoPyTorch/api/base_task.py +++ b/autoPyTorch/api/base_task.py @@ -183,6 +183,9 @@ def __init__( self.trajectory: Optional[List] = None self.dataset_name: Optional[str] = None self.cv_models_: Dict = {} + self.precision: Optional[int] = None + self.opt_metric: Optional[str] = None + self.dataset: Optional[BaseDataset] = None # By default try to use the TCP logging port or get a new port self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT @@ -412,6 +415,25 @@ def _close_dask_client(self) -> None: self._is_dask_client_internally_created = False del self._is_dask_client_internally_created + def _cleanup(self) -> None: + """ + Closes the dask client and cleans up the logger created during the API search. + + Returns: + None + """ + if self._logger is not None: + self._logger.info("Closing the dask infrastructure") + self._close_dask_client() + self._logger.info("Finished closing the dask infrastructure") + + # Clean up the logger + self._logger.info("Starting to clean up the logger") + self._clean_logger() + else: + self._close_dask_client() + def _load_models(self) -> bool: """ @@ -783,7 +805,7 @@ def _search( metrics supporting current task will be calculated for each pipeline and results will be available via cv_results precision (int), (default=32): Numeric precision used when loading - ensemble data. Can be either '16', '32' or '64'. + ensemble data. Can be either 16, 32 or 64. disable_file_output (Union[bool, List]): load_models (bool), (default=True): Whether to load the models after fitting AutoPyTorch. @@ -910,6 +932,8 @@ def _search( self._stopwatch.stop_task(traditional_task_name) # ============> Starting ensemble + self.precision = precision + self.opt_metric = optimize_metric elapsed_time = self._stopwatch.wall_elapsed(self.dataset_name) time_left_for_ensembles = max(0, total_walltime_limit - elapsed_time) proc_ensemble = None @@ -926,28 +950,12 @@ def _search( self._logger.info("Starting ensemble") ensemble_task_name = 'ensemble' self._stopwatch.start_task(ensemble_task_name) - proc_ensemble = EnsembleBuilderManager( - start_time=time.time(), - time_left_for_ensembles=time_left_for_ensembles, - backend=copy.deepcopy(self._backend), - dataset_name=str(dataset.dataset_name), - output_type=STRING_TO_OUTPUT_TYPES[dataset.output_type], - task_type=STRING_TO_TASK_TYPES[self.task_type], - metrics=[self._metric], - opt_metric=optimize_metric, - ensemble_size=self.ensemble_size, - ensemble_nbest=self.ensemble_nbest, - max_models_on_disc=self.max_models_on_disc, - seed=self.seed, - max_iterations=None, - read_at_most=sys.maxsize, - ensemble_memory_limit=self._memory_limit, - random_state=self.seed, - precision=precision, - logger_port=self._logger_port, - pynisher_context=self._multiprocessing_context, - ) - self._stopwatch.stop_task(ensemble_task_name) + proc_ensemble = self._init_ensemble_builder(time_left_for_ensembles=time_left_for_ensembles, + ensemble_size=self.ensemble_size, + ensemble_nbest=self.ensemble_nbest, + precision=precision, + optimize_metric=self.opt_metric + ) # ==> Run SMAC smac_task_name: str = 'runSMAC' @@ -1028,18 +1036,12 @@ def _search( pd.DataFrame(self.ensemble_performance_history).to_json( os.path.join(self._backend.internals_directory, 'ensemble_history.json')) - self._logger.info("Closing the dask infrastructure") - self._close_dask_client() - self._logger.info("Finished closing the dask infrastructure") - if load_models: self._logger.info("Loading 
models...") self._load_models() self._logger.info("Finished loading models...") - # Clean up the logger - self._logger.info("Starting to clean up the logger") - self._clean_logger() + self._cleanup() return self @@ -1114,7 +1116,7 @@ def refit( # the ordering of the data. fit_and_suppress_warnings(self._logger, model, X, y=None) - self._clean_logger() + self._cleanup() return self @@ -1179,9 +1181,139 @@ def fit(self, fit_and_suppress_warnings(self._logger, pipeline, X, y=None) - self._clean_logger() + self._cleanup() return pipeline + def fit_ensemble( + self, + ensemble_nbest: int = 50, + ensemble_size: int = 50, + precision: int = 32, + load_models: bool = True + ) -> 'BaseTask': + """ + Enables post-hoc fitting of the ensemble after the `search()` + method is finished. This method creates an ensemble using all + the models stored on disk during the smbo run + Args: + ensemble_nbest (Optional[int]): + only consider the ensemble_nbest models to build the ensemble. + If None, uses the value stored in class attribute `ensemble_nbest`. + ensemble_size (int) (default=50): + Number of models added to the ensemble built by + Ensemble selection from libraries of models. + Models are drawn with replacement. + precision (int), (default=32): Numeric precision used when loading + ensemble data. Can be either 16, 32 or 64. + + Returns: + self + """ + # Make sure that input is valid + if self.dataset is None or self.opt_metric is None: + raise ValueError("fit_ensemble() can only be called after `search()`. " + "Please call the `search()` method of {} prior to " + "fit_ensemble().".format(self.__class__.__name__)) + + if self._logger is None: + self._logger = self._get_logger(self.dataset.dataset_name) + + # Create a client if needed + if self._dask_client is None: + self._create_dask_client() + else: + self._is_dask_client_internally_created = False + + manager = self._init_ensemble_builder( + time_left_for_ensembles=self._time_for_task, + optimize_metric=self.opt_metric, + precision=precision, + ensemble_size=ensemble_size, + ensemble_nbest=ensemble_nbest, + ) + + manager.build_ensemble(self._dask_client) + future = manager.futures.pop() + result = future.result() + if result is None: + raise ValueError("Errors occurred while building the ensemble - please" + " check the log file and command line output for error messages.") + self.ensemble_performance_history, _, _, _ = result + + if load_models: + self._load_models() + self._cleanup() + return self + + def _init_ensemble_builder( + self, + time_left_for_ensembles: float, + optimize_metric: str, + ensemble_nbest: int, + ensemble_size: int, + precision: int = 32, + ) -> EnsembleBuilderManager: + """ + Initializes an `EnsembleBuilderManager`. + + Args: + time_left_for_ensembles (float): + Time (in seconds) allocated to building the ensemble + optimize_metric (str): + Name of the metric to optimize the ensemble. + ensemble_nbest (int): + only consider the ensemble_nbest models to build the ensemble. + ensemble_size (int): + Number of models added to the ensemble built by + Ensemble selection from libraries of models. + Models are drawn with replacement. + precision (int), (default=32): Numeric precision used when loading + ensemble data. Can be either 16, 32 or 64. + + Returns: + EnsembleBuilderManager + + """ + if self._logger is None: + raise ValueError("logger should be initialized to fit ensemble") + if self.dataset is None: + raise ValueError("ensemble can only be initialised after or during `search()`. 
" + "Please call the `search()` method of {}.".format(self.__class__.__name__)) + + self._logger.info("Starting ensemble") + ensemble_task_name = 'ensemble' + self._stopwatch.start_task(ensemble_task_name) + + # Use the current thread to start the ensemble builder process + # The function ensemble_builder_process will internally create a ensemble + # builder in the provide dask client + required_dataset_properties = {'task_type': self.task_type, + 'output_type': self.dataset.output_type} + proc_ensemble = EnsembleBuilderManager( + start_time=time.time(), + time_left_for_ensembles=time_left_for_ensembles, + backend=copy.deepcopy(self._backend), + dataset_name=str(self.dataset.dataset_name), + output_type=STRING_TO_OUTPUT_TYPES[self.dataset.output_type], + task_type=STRING_TO_TASK_TYPES[self.task_type], + metrics=[self._metric] if self._metric is not None else get_metrics( + dataset_properties=required_dataset_properties, names=[optimize_metric]), + opt_metric=optimize_metric, + ensemble_size=ensemble_size, + ensemble_nbest=ensemble_nbest, + max_models_on_disc=self.max_models_on_disc, + seed=self.seed, + max_iterations=None, + read_at_most=sys.maxsize, + ensemble_memory_limit=self._memory_limit, + random_state=self.seed, + precision=precision, + logger_port=self._logger_port, + pynisher_context=self._multiprocessing_context, + ) + self._stopwatch.stop_task(ensemble_task_name) + return proc_ensemble + def predict( self, X_test: np.ndarray, @@ -1230,7 +1362,7 @@ def predict( predictions = self.ensemble_.predict(all_predictions) - self._clean_logger() + self._cleanup() return predictions @@ -1267,10 +1399,7 @@ def __getstate__(self) -> Dict[str, Any]: return self.__dict__ def __del__(self) -> None: - # Clean up the logger - self._clean_logger() - - self._close_dask_client() + self._cleanup() # When a multiprocessing work is done, the # objects are deleted. 
We don't want to delete run areas diff --git a/autoPyTorch/datasets/base_dataset.py b/autoPyTorch/datasets/base_dataset.py index 15a6dedf9..fee7210c2 100644 --- a/autoPyTorch/datasets/base_dataset.py +++ b/autoPyTorch/datasets/base_dataset.py @@ -109,10 +109,7 @@ def __init__( val_transforms (Optional[torchvision.transforms.Compose]): Additional Transforms to be applied to the validation/test data """ - self.dataset_name = dataset_name - - if self.dataset_name is None: - self.dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) + self.dataset_name: str = dataset_name if dataset_name is not None else str(uuid.uuid1(clock_seq=os.getpid())) if not hasattr(train_tensors[0], 'shape'): type_check(train_tensors, val_tensors) diff --git a/test/test_api/.tmp_api/traditional_run_history.json b/test/test_api/.tmp_api/traditional_run_history.json index f7a106c82..ef8c865cc 100644 --- a/test/test_api/.tmp_api/traditional_run_history.json +++ b/test/test_api/.tmp_api/traditional_run_history.json @@ -8,13 +8,13 @@ 0.0 ], [ - 0.20467836257309946, - 48.634921073913574, + 0.12121212121212122, + 6.456650972366333, { "__enum__": "StatusType.SUCCESS" }, - 0.0, - 0.0, + 1623939184.489662, + 1623939190.946313, { "trainer_configuration": { "num_rounds": 10000, @@ -28,46 +28,14 @@ }, "configuration_origin": "traditional", "opt_loss": { - "accuracy": 0.20467836257309946, - "balanced_accuracy": 0.20607553366174058, - "roc_auc": 0.0847016967706623, - "average_precision": 0.0827746781030202, - "log_loss": 0.4155085084208271, - "precision": 0.15492957746478875, - "precision_macro": 0.19746478873239437, - "precision_micro": 0.20467836257309946, - "precision_weighted": 0.19821102050901895, - "recall": 0.2857142857142857, - "recall_macro": 0.20607553366174058, - "recall_micro": 0.20467836257309946, - "recall_weighted": 0.20467836257309946, - "f1": 0.22580645161290325, - "f1_macro": 0.2064861135069863, - "f1_micro": 0.20467836257309946, - "f1_weighted": 0.2061471602068825 - }, - "duration": 30.787471771240234, + "accuracy": 0.12121212121212122 + }, + "duration": 6.225315809249878, "num_run": 2, "train_loss": { - "accuracy": 0.0, - "balanced_accuracy": 0.0, - "roc_auc": 0.0, - "average_precision": 0.0, - "log_loss": 0.1163032455208329, - "precision": 0.0, - "precision_macro": 0.0, - "precision_micro": 0.0, - "precision_weighted": 0.0, - "recall": 0.0, - "recall_macro": 0.0, - "recall_micro": 0.0, - "recall_weighted": 0.0, - "f1": 0.0, - "f1_macro": 0.0, - "f1_micro": 0.0, - "f1_weighted": 0.0 - }, - "test_loss": 0.138728323699422 + "accuracy": 0.0 + }, + "test_loss": 0.040000000000000036 } ] ], @@ -79,61 +47,28 @@ 0.0 ], [ - 0.14619883040935677, - 24.41903591156006, + 0.0757575757575758, + 25.649624824523926, { "__enum__": "StatusType.SUCCESS" }, - 0.0, - 0.0, + 1623939184.489662, + 1623939210.1392868, { "trainer_configuration": { "iterations": 10000, - "learning_rate": 0.1, - "eval_metric": "Accuracy" + "learning_rate": 0.1 }, "configuration_origin": "traditional", "opt_loss": { - "accuracy": 0.14619883040935677, - "balanced_accuracy": 0.14573070607553373, - "roc_auc": 0.09530651340996166, - "average_precision": 0.09777406254428278, - "log_loss": 0.5589205214851781, - "precision": 0.1685393258426966, - "precision_macro": 0.1452452726774458, - "precision_micro": 0.14619883040935677, - "precision_weighted": 0.1448366050780555, - "recall": 0.11904761904761907, - "recall_macro": 0.14573070607553373, - "recall_micro": 0.14619883040935677, - "recall_weighted": 0.14619883040935677, - "f1": 0.1445086705202312, - "f1_macro": 
0.14621883230153576, - "f1_micro": 0.14619883040935677, - "f1_weighted": 0.14624883513980425 - }, - "duration": 9.664803266525269, + "accuracy": 0.0757575757575758 + }, + "duration": 25.479766845703125, "num_run": 3, "train_loss": { - "accuracy": 0.138728323699422, - "balanced_accuracy": 0.13000374748748, - "roc_auc": 0.05154498688379383, - "average_precision": 0.05783407475676716, - "log_loss": 0.5370512441920408, - "precision": 0.21468926553672318, - "precision_macro": 0.13693043158492957, - "precision_micro": 0.138728323699422, - "precision_weighted": 0.1261430788979756, - "recall": 0.06711409395973156, - "recall_macro": 0.13000374748748, - "recall_micro": 0.138728323699422, - "recall_weighted": 0.138728323699422, - "f1": 0.1472392638036809, - "f1_macro": 0.13919340239364375, - "f1_micro": 0.138728323699422, - "f1_weighted": 0.13807721352751146 - }, - "test_loss": 0.12716763005780352 + "accuracy": 0.06716417910447758 + }, + "test_loss": 0.06999999999999995 } ] ], @@ -145,13 +80,13 @@ 0.0 ], [ - 0.14035087719298245, - 18.845818758010864, + 0.10606060606060608, + 6.964690923690796, { "__enum__": "StatusType.SUCCESS" }, - 0.0, - 0.0, + 1623939184.489662, + 1623939191.4543529, { "trainer_configuration": { "n_estimators": 300, @@ -159,46 +94,113 @@ }, "configuration_origin": "traditional", "opt_loss": { - "accuracy": 0.14035087719298245, - "balanced_accuracy": 0.14142036124794743, - "roc_auc": 0.08401751505199773, - "average_precision": 0.0788213312884698, - "log_loss": 0.37833770927673543, - "precision": 0.09459459459459463, - "precision_macro": 0.13492616327667872, - "precision_micro": 0.14035087719298245, - "precision_weighted": 0.1356337346570663, - "recall": 0.20238095238095233, - "recall_macro": 0.14142036124794743, - "recall_micro": 0.14035087719298245, - "recall_weighted": 0.14035087719298245, - "f1": 0.15189873417721522, - "f1_macro": 0.14116675839295545, - "f1_micro": 0.14035087719298245, - "f1_weighted": 0.1409784781160387 - }, - "duration": 4.936332941055298, + "accuracy": 0.10606060606060608 + }, + "duration": 6.7608888149261475, "num_run": 4, "train_loss": { - "accuracy": 0.0, - "balanced_accuracy": 0.0, - "roc_auc": 0.0, - "average_precision": 2.220446049250313e-16, - "log_loss": 0.0899028721860357, - "precision": 0.0, - "precision_macro": 0.0, - "precision_micro": 0.0, - "precision_weighted": 0.0, - "recall": 0.0, - "recall_macro": 0.0, - "recall_micro": 0.0, - "recall_weighted": 0.0, - "f1": 0.0, - "f1_macro": 0.0, - "f1_micro": 0.0, - "f1_weighted": 0.0 - }, - "test_loss": 0.1445086705202312 + "accuracy": 0.0 + }, + "test_loss": 0.03500000000000003 + } + ] + ], + [ + [ + 4, + null, + 1, + 0.0 + ], + [ + 0.12121212121212122, + 5.65510892868042, + { + "__enum__": "StatusType.SUCCESS" + }, + 1623939184.489662, + 1623939190.1447709, + { + "trainer_configuration": { + "n_estimators": 300, + "n_jobs": -1 + }, + "configuration_origin": "traditional", + "opt_loss": { + "accuracy": 0.12121212121212122 + }, + "duration": 5.463550090789795, + "num_run": 5, + "train_loss": { + "accuracy": 0.0 + }, + "test_loss": 0.040000000000000036 + } + ] + ], + [ + [ + 5, + null, + 1, + 0.0 + ], + [ + 0.10606060606060608, + 1.6963858604431152, + { + "__enum__": "StatusType.SUCCESS" + }, + 1623939184.489662, + 1623939186.1860478, + { + "trainer_configuration": { + "C": 1.0, + "degree": 3 + }, + "configuration_origin": "traditional", + "opt_loss": { + "accuracy": 0.10606060606060608 + }, + "duration": 1.5256588459014893, + "num_run": 6, + "train_loss": { + "accuracy": 0.08208955223880599 + }, + 
"test_loss": 0.08999999999999997 + } + ] + ], + [ + [ + 6, + null, + 1, + 0.0 + ], + [ + 0.0757575757575758, + 2.1201720237731934, + { + "__enum__": "StatusType.SUCCESS" + }, + 1623939184.489662, + 1623939186.609834, + { + "trainer_configuration": { + "weights": "uniform", + "n_jobs": -1 + }, + "configuration_origin": "traditional", + "opt_loss": { + "accuracy": 0.0757575757575758 + }, + "duration": 1.9614458084106445, + "num_run": 7, + "train_loss": { + "accuracy": 0.07462686567164178 + }, + "test_loss": 0.07499999999999996 } ] ] @@ -206,16 +208,28 @@ "config_origins": {}, "configs": { "1": { - "model_trainer:__choice__": "tabular_classifier", - "model_trainer:tabular_classifier:classifier": "lgb" + "model_trainer:__choice__": "tabular_traditional_model", + "model_trainer:tabular_traditional_model:traditional_learner": "lgb" }, "2": { - "model_trainer:__choice__": "tabular_classifier", - "model_trainer:tabular_classifier:classifier": "catboost" + "model_trainer:__choice__": "tabular_traditional_model", + "model_trainer:tabular_traditional_model:traditional_learner": "catboost" }, "3": { - "model_trainer:__choice__": "tabular_classifier", - "model_trainer:tabular_classifier:classifier": "random_forest" + "model_trainer:__choice__": "tabular_traditional_model", + "model_trainer:tabular_traditional_model:traditional_learner": "random_forest" + }, + "4": { + "model_trainer:__choice__": "tabular_traditional_model", + "model_trainer:tabular_traditional_model:traditional_learner": "extra_trees" + }, + "5": { + "model_trainer:__choice__": "tabular_traditional_model", + "model_trainer:tabular_traditional_model:traditional_learner": "svm" + }, + "6": { + "model_trainer:__choice__": "tabular_traditional_model", + "model_trainer:tabular_traditional_model:traditional_learner": "knn" } } } \ No newline at end of file diff --git a/test/test_api/test_api.py b/test/test_api/test_api.py index 5f670e59d..aa6095b41 100644 --- a/test/test_api/test_api.py +++ b/test/test_api/test_api.py @@ -494,8 +494,7 @@ def test_do_dummy_prediction(dask_client, fit_dictionary_tabular): with open(model_path, 'rb') as model_handler: clone(pickle.load(model_handler)) - estimator._close_dask_client() - estimator._clean_logger() + estimator._cleanup() del estimator @@ -694,8 +693,7 @@ def test_do_traditional_pipeline(fit_dictionary_tabular): if not at_least_one_model_checked: pytest.fail("Not even one single traditional pipeline was fitted") - estimator._close_dask_client() - estimator._clean_logger() + estimator._cleanup() del estimator @@ -714,3 +712,66 @@ def test_build_pipeline(api_type, fit_dictionary_tabular): pipeline = api.build_pipeline(fit_dictionary_tabular['dataset_properties']) assert isinstance(pipeline, BaseEstimator) assert len(pipeline.steps) > 0 + + +@unittest.mock.patch('autoPyTorch.evaluation.train_evaluator.eval_function', + new=dummy_eval_function) +@pytest.mark.parametrize('dataset_name', ('iris',)) +def test_fit_ensemble(backend, n_samples, dataset_name): + # Get the data and check that contents of data-manager make sense + X, y = sklearn.datasets.fetch_openml( + name=dataset_name, + return_X_y=True, as_frame=True + ) + X, y = X.iloc[:n_samples], y.iloc[:n_samples] + + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X, y, random_state=42) + + # Search for a good configuration + estimator = TabularClassificationTask( + backend=backend, + seed=42, + ensemble_size=0, + ) + + with unittest.mock.patch.object(estimator, '_do_dummy_prediction', new=dummy_do_dummy_prediction): + 
estimator.search( + X_train=X_train, y_train=y_train, + X_test=X_test, y_test=y_test, + optimize_metric='accuracy', + total_walltime_limit=40, + func_eval_time_limit_secs=10, + enable_traditional_pipeline=False, + ) + + estimator.fit_ensemble(ensemble_size=2) + assert isinstance(estimator.ensemble_performance_history, list) + assert 'train_accuracy' in estimator.ensemble_performance_history[0] + assert 'test_accuracy' in estimator.ensemble_performance_history[0] + + assert os.path.exists(os.path.join(estimator._backend.internals_directory, 'ensembles')) + assert len(os.listdir(os.path.join(estimator._backend.internals_directory, 'ensembles'))) > 0 + assert any(['.ensemble' in file for file in os.listdir(os.path.join( + estimator._backend.internals_directory, 'ensembles'))]) + assert any('ensemble_' in file or '_ensemble.npy' in file for file in os.listdir(estimator._backend.internals_directory)) + + preds = estimator.predict(X_test) + assert isinstance(preds, np.ndarray) + + assert len(estimator.ensemble_performance_history) > 0 + + +@pytest.mark.parametrize('dataset_name', ('iris',)) +def test_fit_ensemble_failure(backend, n_samples, dataset_name): + # Search for a good configuration + estimator = TabularClassificationTask( + backend=backend, + seed=42, + ensemble_size=0, + ) + + with pytest.raises(ValueError, + match=r"fit_ensemble\(\) can only be called after `search\(\)`. " + r"Please call the `search\(\)` method of [A-Za-z]+ prior to fit_ensemble\(\)."): + estimator.fit_ensemble(ensemble_size=2)
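Reviewer note: a minimal usage sketch of the new post-hoc `fit_ensemble()` API, mirroring `test_fit_ensemble` above; run `search()` with `ensemble_size=0`, then build the ensemble afterwards from the models stored on disk. The `TabularClassificationTask` import path is assumed from the rest of the codebase, and the dataset, time budgets, and ensemble sizes are illustrative only.

import numpy as np
import sklearn.datasets
import sklearn.model_selection

from autoPyTorch.api.tabular_classification import TabularClassificationTask  # assumed import path

# Fetch a small dataset (illustrative choice, as in the test above)
X, y = sklearn.datasets.fetch_openml(name='iris', return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y, random_state=42)

# Run the search without building an ensemble (ensemble_size=0) ...
estimator = TabularClassificationTask(seed=42, ensemble_size=0)
estimator.search(
    X_train=X_train, y_train=y_train,
    X_test=X_test, y_test=y_test,
    optimize_metric='accuracy',
    total_walltime_limit=60,       # illustrative budget (seconds)
    func_eval_time_limit_secs=10,  # illustrative per-evaluation limit
)

# ... then build the ensemble post hoc from the models stored on disk.
estimator.fit_ensemble(ensemble_size=5, ensemble_nbest=10, precision=32)

# The ensemble is loaded by default (load_models=True), so predict() works directly.
preds = estimator.predict(X_test)
assert isinstance(preds, np.ndarray)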