Search space update (#80)
* Added Hyperparameter Search space updates

* added test for search space update

* Added Hyperparameter Search space updates

* added test for search space update

* Added hyperparameter search space updates to network, trainer and improved check for search space updates

* Fix mypy, flake8

* Fix tests and silly mistake in base_pipeline

* Fix flake

* added _cs_updates to dummy component

* fixed indentation and isinstance comment

* fixed silly error

* Addressed comments from Francisco

* added value error for search space updates

* ADD tests for setting range of config space

* fix utils search space update
ravinkohli committed Feb 1, 2021
1 parent 2e7b462 commit ddc0f3d
Showing 54 changed files with 872 additions and 323 deletions.
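
The diffs below thread one new argument, search_space_updates, from the public API down to pipeline construction. For orientation, a minimal sketch of the intended call pattern, assuming the append-style interface that autoPyTorch/utils/hyperparameter_search_space_update.py exposes; the node and hyperparameter names here are illustrative placeholders, not taken from this diff:

from autoPyTorch.api.tabular_classification import TabularClassificationTask
from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

# Each appended entry narrows one hyperparameter of one pipeline node;
# node_name and hyperparameter below are illustrative placeholders.
updates = HyperparameterSearchSpaceUpdates()
updates.append(node_name='lr_scheduler',
               hyperparameter='CosineAnnealingLR:T_max',
               value_range=[50, 200],
               default_value=100)

# BaseTask.__init__ type-checks the argument (see base_task.py below) and
# forwards it to the search space builder, SMBO, the TAE and the evaluators.
api = TabularClassificationTask(search_space_updates=updates)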
15 changes: 13 additions & 2 deletions autoPyTorch/api/base_task.py
@@ -46,6 +46,7 @@
from autoPyTorch.pipeline.components.training.metrics.utils import calculate_score, get_metrics
from autoPyTorch.utils.backend import Backend, create
from autoPyTorch.utils.common import FitRequirement, replace_string_bool_to_bool
+ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import (
PicklableClientLogger,
get_named_client_logger,
@@ -135,6 +136,7 @@ def __init__(
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
backend: Optional[Backend] = None,
+ search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
) -> None:
self.seed = seed
self.n_jobs = n_jobs
@@ -178,6 +180,13 @@ def __init__(

self.stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event]

+ self.search_space_updates = search_space_updates
+ if search_space_updates is not None:
+     if not isinstance(self.search_space_updates, HyperparameterSearchSpaceUpdates):
+         raise ValueError("Expected search space updates to be an instance of"
+                          " HyperparameterSearchSpaceUpdates, got {}".format(type(self.search_space_updates)))

@abstractmethod
def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
"""
@@ -252,7 +261,8 @@ def get_search_space(self, dataset: BaseDataset = None) -> ConfigurationSpace:
info=self._get_required_dataset_properties(dataset))
return get_configuration_space(info=dataset.get_dataset_properties(dataset_requirements),
include=self.include_components,
- exclude=self.exclude_components)
+ exclude=self.exclude_components,
+ search_space_updates=self.search_space_updates)
raise Exception("No search space initialised and no dataset passed. "
"Can't create default search space without the dataset")

@@ -816,7 +826,8 @@ def search(
pipeline_config={**self.pipeline_options, **budget_config},
ensemble_callback=proc_ensemble,
logger_port=self._logger_port,
- start_num_run=num_run
+ start_num_run=num_run,
+ search_space_updates=self.search_space_updates
)
try:
self.run_history, self.trajectory, budget_type = \
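
For intuition on what the forwarded updates eventually do: get_configuration_space turns them into tightened ranges inside a ConfigSpace ConfigurationSpace. A self-contained sketch of that mechanism, not the actual implementation (which lives in autoPyTorch/utils/hyperparameter_search_space_update.py and the pipeline's get_hyperparameter_search_space):

from ConfigSpace.configuration_space import ConfigurationSpace
from ConfigSpace.hyperparameters import UniformFloatHyperparameter

def apply_range_update(cs, name, value_range, default_value):
    # Rebuild the space with `name` restricted to the requested range;
    # every other hyperparameter is copied over unchanged.
    updated = ConfigurationSpace()
    for hp in cs.get_hyperparameters():
        if hp.name == name:
            hp = UniformFloatHyperparameter(name,
                                            lower=value_range[0],
                                            upper=value_range[1],
                                            default_value=default_value,
                                            log=hp.log)
        updated.add_hyperparameter(hp)
    return updated

cs = ConfigurationSpace()
cs.add_hyperparameter(UniformFloatHyperparameter('lr', 1e-5, 1.0, default_value=1e-2, log=True))
cs = apply_range_update(cs, 'lr', (1e-4, 1e-2), 1e-3)  # search now only explores [1e-4, 1e-2]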
3 changes: 3 additions & 0 deletions autoPyTorch/api/tabular_classification.py
@@ -9,6 +9,7 @@
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.utils.backend import Backend
+ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates


class TabularClassificationTask(BaseTask):
@@ -52,6 +53,7 @@ def __init__(
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
backend: Optional[Backend] = None,
+ search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
):
super().__init__(
seed=seed,
@@ -67,6 +69,7 @@
include_components=include_components,
exclude_components=exclude_components,
backend=backend,
+ search_space_updates=search_space_updates
)
self.task_type = TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION]

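
The net effect for TabularClassificationTask users: anything other than a HyperparameterSearchSpaceUpdates instance is rejected up front by the BaseTask check added above. A sketch of that failure mode (assuming the task can otherwise be constructed with defaults):

from autoPyTorch.api.tabular_classification import TabularClassificationTask

try:
    # A plain dict is not a HyperparameterSearchSpaceUpdates instance,
    # so BaseTask.__init__ raises before any search is started.
    TabularClassificationTask(search_space_updates={'lr': (1e-4, 1e-2)})
except ValueError as err:
    print(err)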
10 changes: 8 additions & 2 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -42,6 +42,7 @@
get_metrics,
)
from autoPyTorch.utils.backend import Backend
+ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger
from autoPyTorch.utils.pipeline import get_dataset_requirements

@@ -200,7 +201,9 @@ def __init__(self, backend: Backend,
disable_file_output: Union[bool, List[str]] = False,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
- all_supported_metrics: bool = True) -> None:
+ all_supported_metrics: bool = True,
+ search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
+ ) -> None:

self.starttime = time.time()

@@ -218,6 +221,7 @@ def __init__(self, backend: Backend,

self.include = include
self.exclude = exclude
+ self.search_space_updates = search_space_updates

self.X_train, self.y_train = self.datamanager.train_tensors

@@ -324,6 +328,7 @@ def __init__(self, backend: Backend,
self.pipelines: Optional[List[BaseEstimator]] = None
self.pipeline: Optional[BaseEstimator] = None
self.logger.debug("Fit dictionary in Abstract evaluator: {}".format(self.fit_dictionary))
+ self.logger.debug("Search space updates: {}".format(self.search_space_updates))

def _get_pipeline(self) -> BaseEstimator:
assert self.pipeline_class is not None, "Can't return pipeline, pipeline_class not initialised"
@@ -337,7 +342,8 @@ def _get_pipeline(self) -> BaseEstimator:
random_state=np.random.RandomState(self.seed),
include=self.include,
exclude=self.exclude,
- init_params=self._init_params)
+ init_params=self._init_params,
+ search_space_updates=self.search_space_updates)
elif isinstance(self.configuration, str):
pipeline = self.pipeline_class(config=self.configuration,
dataset_properties=self.dataset_properties,
8 changes: 7 additions & 1 deletion autoPyTorch/evaluation/tae.py
@@ -25,6 +25,7 @@
from autoPyTorch.evaluation.utils import empty_queue, extract_learning_curve, read_queue
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.utils.backend import Backend
+ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import PicklableClientLogger, get_named_client_logger


@@ -111,6 +112,7 @@ def __init__(
ta: typing.Optional[typing.Callable] = None,
logger_port: int = None,
all_supported_metrics: bool = True,
+ search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None
):

eval_function = autoPyTorch.evaluation.train_evaluator.eval_function
@@ -164,6 +166,8 @@ def __init__(
self.resampling_strategy = dm.resampling_strategy
self.resampling_strategy_args = dm.resampling_strategy_args

+ self.search_space_updates = search_space_updates

def run_wrapper(
self,
run_info: RunInfo,
@@ -250,6 +254,7 @@ def run(
else:
num_run = config.config_id + self.initial_num_run

+ self.logger.debug("Search space updates: {}".format(self.search_space_updates))
obj_kwargs = dict(
queue=queue,
config=config,
@@ -267,7 +272,8 @@
budget_type=self.budget_type,
pipeline_config=self.pipeline_config,
logger_port=self.logger_port,
- all_supported_metrics=self.all_supported_metrics
+ all_supported_metrics=self.all_supported_metrics,
+ search_space_updates=self.search_space_updates
)

info: typing.Optional[typing.List[RunValue]]
12 changes: 9 additions & 3 deletions autoPyTorch/evaluation/train_evaluator.py
@@ -20,6 +20,7 @@
from autoPyTorch.evaluation.utils import subsampler
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.utils.backend import Backend
+ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates

__all__ = ['TrainEvaluator', 'eval_function']

@@ -48,7 +49,8 @@ def __init__(self, backend: Backend, queue: Queue,
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
keep_models: Optional[bool] = None,
- all_supported_metrics: bool = True) -> None:
+ all_supported_metrics: bool = True,
+ search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None) -> None:
super().__init__(
backend=backend,
queue=queue,
@@ -65,7 +67,8 @@ def __init__(self, backend: Backend, queue: Queue,
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
- pipeline_config=pipeline_config
+ pipeline_config=pipeline_config,
+ search_space_updates=search_space_updates
)

self.splits = self.datamanager.splits
@@ -77,6 +80,7 @@ def __init__(self, backend: Backend, queue: Queue,
self.pipelines: List[Optional[BaseEstimator]] = [None] * self.num_folds
self.indices: List[Optional[Tuple[Union[np.ndarray, List], Union[np.ndarray, List]]]] = [None] * self.num_folds

+ self.logger.debug("Search space updates: {}".format(self.search_space_updates))
self.keep_models = keep_models

def fit_predict_and_loss(self) -> None:
@@ -320,6 +324,7 @@ def eval_function(
init_params: Optional[Dict[str, Any]] = None,
logger_port: Optional[int] = None,
all_supported_metrics: bool = True,
+ search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
instance: str = None,
) -> None:
evaluator = TrainEvaluator(
@@ -338,6 +343,7 @@
budget_type=budget_type,
logger_port=logger_port,
all_supported_metrics=all_supported_metrics,
- pipeline_config=pipeline_config
+ pipeline_config=pipeline_config,
+ search_space_updates=search_space_updates
)
evaluator.fit_predict_and_loss()
14 changes: 8 additions & 6 deletions autoPyTorch/optimizer/smbo.py
@@ -16,17 +16,17 @@
from smac.tae.serial_runner import SerialRunner
from smac.utils.io.traj_logging import TrajEntry

- # TODO: Enable when merged Ensemble
- # from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
DEFAULT_RESAMPLING_PARAMETERS,
HoldoutValTypes,
)
+ from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
from autoPyTorch.pipeline.components.training.metrics.base import autoPyTorchMetric
from autoPyTorch.utils.backend import Backend
+ from autoPyTorch.utils.hyperparameter_search_space_update import HyperparameterSearchSpaceUpdates
from autoPyTorch.utils.logging_ import get_named_client_logger
from autoPyTorch.utils.stopwatch import StopWatch

@@ -101,10 +101,9 @@ def __init__(self,
smac_scenario_args: typing.Optional[typing.Dict[str, typing.Any]] = None,
get_smac_object_callback: typing.Optional[typing.Callable] = None,
all_supported_metrics: bool = True,
- # TODO: Re-enable when ensemble merged
- # ensemble_callback: typing.Optional[EnsembleBuilderManager] = None,
- ensemble_callback: typing.Any = None,
+ ensemble_callback: typing.Optional[EnsembleBuilderManager] = None,
logger_port: typing.Optional[int] = None,
+ search_space_updates: typing.Optional[HyperparameterSearchSpaceUpdates] = None
):
"""
Interface to SMAC. This method calls the SMAC optimize method, and allows
@@ -194,6 +193,8 @@ def __init__(self,

self.ensemble_callback = ensemble_callback

+ self.search_space_updates = search_space_updates

dataset_name_ = "" if dataset_name is None else dataset_name
if logger_port is None:
self.logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT
@@ -254,7 +255,8 @@ def run_smbo(self, func: typing.Optional[typing.Callable] = None
ta=func,
logger_port=self.logger_port,
all_supported_metrics=self.all_supported_metrics,
- pipeline_config=self.pipeline_config
+ pipeline_config=self.pipeline_config,
+ search_space_updates=self.search_space_updates
)
ta = ExecuteTaFuncWithQueue
self.logger.info("Created TA")
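
Taken together, the diffs form a single chain: the task validates search_space_updates and hands them to AutoMLSMBO, which passes them to ExecuteTaFuncWithQueue, which forwards them into each TrainEvaluator, which in turn gives them to the pipeline it builds, so every SMAC run optimizes over the restricted space.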
