Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handling Input to auto pytorch #89

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 33 additions & 35 deletions autoPyTorch/api/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,21 +122,24 @@ class BaseTask:
"""

def __init__(
self,
seed: int = 1,
n_jobs: int = 1,
logging_config: Optional[Dict] = None,
ensemble_size: int = 50,
ensemble_nbest: int = 50,
max_models_on_disc: int = 50,
temporary_directory: Optional[str] = None,
output_directory: Optional[str] = None,
delete_tmp_folder_after_terminate: bool = True,
delete_output_folder_after_terminate: bool = True,
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
backend: Optional[Backend] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
self,
seed: int = 1,
n_jobs: int = 1,
logging_config: Optional[Dict] = None,
ensemble_size: int = 50,
ensemble_nbest: int = 50,
max_models_on_disc: int = 50,
temporary_directory: Optional[str] = None,
output_directory: Optional[str] = None,
delete_tmp_folder_after_terminate: bool = True,
delete_output_folder_after_terminate: bool = True,
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
backend: Optional[Backend] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
task_type: Optional[str] = None
) -> None:
self.seed = seed
self.n_jobs = n_jobs
Expand All @@ -157,14 +160,14 @@ def __init__(
delete_tmp_folder_after_terminate=delete_tmp_folder_after_terminate,
delete_output_folder_after_terminate=delete_output_folder_after_terminate,
)
self.task_type = task_type
self._stopwatch = StopWatch()

self.pipeline_options = replace_string_bool_to_bool(json.load(open(
os.path.join(os.path.dirname(__file__), '../configs/default_pipeline_options.json'))))

self.search_space: Optional[ConfigurationSpace] = None
self._dataset_requirements: Optional[List[FitRequirement]] = None
self.task_type: Optional[str] = None
self._metric: Optional[autoPyTorchMetric] = None
self._logger: Optional[PicklableClientLogger] = None
self.run_history: Optional[RunHistory] = None
Expand All @@ -176,7 +179,8 @@ def __init__(
self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT

# Store the resampling strategy from the dataset, to load models as needed
self.resampling_strategy = None # type: Optional[Union[CrossValTypes, HoldoutValTypes]]
self.resampling_strategy = resampling_strategy
self.resampling_strategy_args = resampling_strategy_args

self.stop_logging_server = None # type: Optional[multiprocessing.synchronize.Event]

Expand Down Expand Up @@ -287,7 +291,7 @@ def _get_logger(self, name: str) -> PicklableClientLogger:
output_dir=self._backend.temporary_directory,
)

# As Auto-sklearn works with distributed process,
# As AutoPyTorch works with distributed process,
# we implement a logger server that can receive tcp
# pickled messages. They are unpickled and processed locally
# under the above logging configuration setting
Expand Down Expand Up @@ -398,20 +402,16 @@ def _close_dask_client(self) -> None:
self._is_dask_client_internally_created = False
del self._is_dask_client_internally_created

def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, HoldoutValTypes]]
) -> bool:
def _load_models(self) -> bool:

"""
Loads the models saved in the temporary directory
during the smac run and the final ensemble created
Args:
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]): resampling strategy used to split the data
and to validate the performance of a candidate pipeline

Returns:
None
"""
if resampling_strategy is None:
if self.resampling_strategy is None:
raise ValueError("Resampling strategy is needed to determine what models to load")
self.ensemble_ = self._backend.load_ensemble(self.seed)

Expand All @@ -422,10 +422,10 @@ def _load_models(self, resampling_strategy: Optional[Union[CrossValTypes, Holdou
if self.ensemble_:
identifiers = self.ensemble_.get_selected_model_identifiers()
self.models_ = self._backend.load_models_by_identifiers(identifiers)
if isinstance(resampling_strategy, CrossValTypes):
if isinstance(self.resampling_strategy, CrossValTypes):
self.cv_models_ = self._backend.load_cv_models_by_identifiers(identifiers)

if isinstance(resampling_strategy, CrossValTypes):
if isinstance(self.resampling_strategy, CrossValTypes):
if len(self.cv_models_) == 0:
raise ValueError('No models fitted!')

Expand Down Expand Up @@ -610,10 +610,10 @@ def _do_traditional_prediction(self, num_run: int, time_for_traditional: int) ->
)
return num_run

def search(
def _search(
self,
dataset: BaseDataset,
optimize_metric: str,
dataset: BaseDataset,
budget_type: Optional[str] = None,
budget: Optional[float] = None,
total_walltime_limit: int = 100,
Expand All @@ -638,6 +638,7 @@ def search(
The argument that will provide the dataset splits. It is
a subclass of the base dataset object which can
generate the splits based on different restrictions.
Providing X_train, y_train and dataset together is not supported.
optimize_metric (str): name of the metric that is used to
evaluate a pipeline.
budget_type (Optional[str]):
Expand Down Expand Up @@ -692,6 +693,7 @@ def search(
self

"""

if self.task_type != dataset.task_type:
raise ValueError("Incompatible dataset entered for current task,"
"expected dataset to have task type :{} got "
Expand All @@ -705,7 +707,6 @@ def search(
dataset_properties = dataset.get_dataset_properties(dataset_requirements)
self._stopwatch.start_task(experiment_task_name)
self.dataset_name = dataset.dataset_name
self.resampling_strategy = dataset.resampling_strategy
self._logger = self._get_logger(self.dataset_name)
self._all_supported_metrics = all_supported_metrics
self._disable_file_output = disable_file_output
Expand Down Expand Up @@ -869,7 +870,7 @@ def search(

if load_models:
self._logger.info("Loading models...")
self._load_models(dataset.resampling_strategy)
self._load_models()
self._logger.info("Finished loading models...")

# Clean up the logger
Expand Down Expand Up @@ -927,7 +928,7 @@ def refit(
})
X.update({**self.pipeline_options, **budget_config})
if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None:
self._load_models(dataset.resampling_strategy)
self._load_models()

# Refit is not applicable when ensemble_size is set to zero.
if self.ensemble_ is None:
Expand Down Expand Up @@ -1025,7 +1026,7 @@ def predict(
if self._logger is None:
self._logger = self._get_logger("Predict-Logger")

if self.ensemble_ is None and not self._load_models(self.resampling_strategy):
if self.ensemble_ is None and not self._load_models():
raise ValueError("No ensemble found. Either fit has not yet "
"been called or no ensemble was fitted")

Expand Down Expand Up @@ -1084,9 +1085,6 @@ def score(
Returns:
Dict[str, float]: Value of the evaluation metric calculated on the test set.
"""
if isinstance(y_test, pd.Series):
y_test = y_test.to_numpy(dtype=np.float)

if self._metric is None:
raise ValueError("No metric found. Either fit/search has not been called yet "
"or AutoPyTorch failed to infer a metric from the dataset ")
Expand Down
171 changes: 168 additions & 3 deletions autoPyTorch/api/tabular_classification.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np

import pandas as pd

from autoPyTorch.api.base_task import BaseTask
from autoPyTorch.constants import (
TABULAR_CLASSIFICATION,
TASK_TYPES_TO_STRING,
)
from autoPyTorch.data.tabular_validator import TabularInputValidator
from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import (
CrossValTypes,
HoldoutValTypes,
)
from autoPyTorch.datasets.tabular_dataset import TabularDataset
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
from autoPyTorch.utils.backend import Backend
Expand Down Expand Up @@ -52,6 +61,8 @@ def __init__(
delete_output_folder_after_terminate: bool = True,
include_components: Optional[Dict] = None,
exclude_components: Optional[Dict] = None,
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
resampling_strategy_args: Optional[Dict[str, Any]] = None,
backend: Optional[Backend] = None,
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None
):
Expand All @@ -69,9 +80,15 @@ def __init__(
include_components=include_components,
exclude_components=exclude_components,
backend=backend,
search_space_updates=search_space_updates
resampling_strategy=resampling_strategy,
resampling_strategy_args=resampling_strategy_args,
search_space_updates=search_space_updates,
task_type=TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION],
)
self.task_type = TASK_TYPES_TO_STRING[TABULAR_CLASSIFICATION]

# Create a validator object to make sure that the data provided by
# the user matches the autopytorch requirements
self.InputValidator = TabularInputValidator(is_classification=True)
franchuterivera marked this conversation as resolved.
Show resolved Hide resolved

def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, Any]:
if not isinstance(dataset, TabularDataset):
Expand All @@ -86,3 +103,151 @@ def _get_required_dataset_properties(self, dataset: BaseDataset) -> Dict[str, An

def build_pipeline(self, dataset_properties: Dict[str, Any]) -> TabularClassificationPipeline:
return TabularClassificationPipeline(dataset_properties=dataset_properties)

def search(
self,
optimize_metric: str,
X_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_train: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
X_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
y_test: Optional[Union[List, pd.DataFrame, np.ndarray]] = None,
budget_type: Optional[str] = None,
budget: Optional[float] = None,
total_walltime_limit: int = 100,
func_eval_time_limit: int = 60,
traditional_per_total_budget: float = 0.1,
memory_limit: Optional[int] = 4096,
smac_scenario_args: Optional[Dict[str, Any]] = None,
get_smac_object_callback: Optional[Callable] = None,
all_supported_metrics: bool = True,
precision: int = 32,
disable_file_output: List = [],
load_models: bool = True,
) -> 'BaseTask':
"""
Search for the best pipeline configuration for the given dataset.

Fit both optimizes the machine learning models and builds an ensemble out of them.
To disable ensembling, set ensemble_size==0.
using the optimizer.
Args:
X_train, y_train, X_test, y_test: Union[np.ndarray, List, pd.DataFrame]
A pair of features (X_train) and targets (y_train) used to fit a
pipeline. Additionally, a holdout of this paris (X_test, y_test) can
franchuterivera marked this conversation as resolved.
Show resolved Hide resolved
be provided to track the generalization performance of each stage.
Providing X_train, y_train and dataset together is not supported.
franchuterivera marked this conversation as resolved.
Show resolved Hide resolved
optimize_metric (str): name of the metric that is used to
evaluate a pipeline.
budget_type (Optional[str]):
Type of budget to be used when fitting the pipeline.
Either 'epochs' or 'runtime'. If not provided, uses
the default in the pipeline config ('epochs')
budget (Optional[float]):
Budget to fit a single run of the pipeline. If not
provided, uses the default in the pipeline config
total_walltime_limit (int), (default=100): Time limit
in seconds for the search of appropriate models.
By increasing this value, autopytorch has a higher
chance of finding better models.
func_eval_time_limit (int), (default=60): Time limit
for a single call to the machine learning model.
Model fitting will be terminated if the machine
learning algorithm runs over the time limit. Set
this value high enough so that typical machine
learning algorithms can be fit on the training
data.
traditional_per_total_budget (float), (default=0.1):
Percent of total walltime to be allocated for
running traditional classifiers.
memory_limit (Optional[int]), (default=4096): Memory
limit in MB for the machine learning algorithm. autopytorch
will stop fitting the machine learning algorithm if it tries
to allocate more than memory_limit MB. If None is provided,
no memory limit is set. In case of multi-processing, memory_limit
will be per job. This memory limit also applies to the ensemble
creation process.
smac_scenario_args (Optional[Dict]): Additional arguments inserted
into the scenario of SMAC. See the
[SMAC documentation] (https://automl.github.io/SMAC3/master/options.html?highlight=scenario#scenario)
get_smac_object_callback (Optional[Callable]): Callback function
to create an object of class
[smac.optimizer.smbo.SMBO](https://automl.github.io/SMAC3/master/apidoc/smac.optimizer.smbo.html).
The function must accept the arguments scenario_dict,
instances, num_params, runhistory, seed and ta. This is
an advanced feature. Use only if you are familiar with
[SMAC](https://automl.github.io/SMAC3/master/index.html).
all_supported_metrics (bool), (default=True): if True, all
metrics supporting current task will be calculated
for each pipeline and results will be available via cv_results
precision (int), (default=32): Numeric precision used when loading
ensemble data. Can be either '16', '32' or '64'.
disable_file_output (Union[bool, List]):
load_models (bool), (default=True): Whether to load the
models after fitting AutoPyTorch.

Returns:
self

"""

# Fit a input validator to check the provided data
# Also, an encoder is fit to both train and test data,
# to prevent unseen categories during inference
self.InputValidator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)

self.dataset = TabularDataset(
X=X_train, Y=y_train,
X_test=X_test, Y_test=y_test,
validator=self.InputValidator,
resampling_strategy=self.resampling_strategy,
resampling_strategy_args=self.resampling_strategy_args,
)

return self._search(
dataset=self.dataset,
optimize_metric=optimize_metric,
budget_type=budget_type,
budget=budget,
total_walltime_limit=total_walltime_limit,
func_eval_time_limit=func_eval_time_limit,
traditional_per_total_budget=traditional_per_total_budget,
memory_limit=memory_limit,
smac_scenario_args=smac_scenario_args,
get_smac_object_callback=get_smac_object_callback,
all_supported_metrics=all_supported_metrics,
precision=precision,
disable_file_output=disable_file_output,
load_models=load_models,
)

def predict(
self,
X_test: np.ndarray,
batch_size: Optional[int] = None,
n_jobs: int = 1
) -> np.ndarray:
if self.InputValidator is None or not self.InputValidator._is_fitted:
raise ValueError("predict() is only supported after calling fit. Kindly call first "
franchuterivera marked this conversation as resolved.
Show resolved Hide resolved
"the estimator fit() method.")

X_test = self.InputValidator.feature_validator.transform(X_test)
predicted_probabilities = super().predict(X_test, batch_size=batch_size,
n_jobs=n_jobs)

if self.InputValidator.target_validator.is_single_column_target():
predicted_indexes = np.argmax(predicted_probabilities, axis=1)
else:
predicted_indexes = (predicted_probabilities > 0.5).astype(int)

# Allow to predict in the original domain -- that is, the user is not interested
# in our encoded values
return self.InputValidator.target_validator.inverse_transform(predicted_indexes)

def predict_proba(self,
X_test: Union[np.ndarray, pd.DataFrame, List],
batch_size: Optional[int] = None, n_jobs: int = 1) -> np.ndarray:
if self.InputValidator is None or not self.InputValidator._is_fitted:
raise ValueError("predict() is only supported after calling fit. Kindly call first "
franchuterivera marked this conversation as resolved.
Show resolved Hide resolved
"the estimator fit() method.")
X_test = self.InputValidator.feature_validator.transform(X_test)
return super().predict(X_test, batch_size=batch_size, n_jobs=n_jobs)
1 change: 1 addition & 0 deletions autoPyTorch/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# -*- encoding: utf-8 -*-