[DOC] Adds documentation to the abstract evaluator (#160)
* DOC_153

* Changes from Ravin

* [FIX] improve clarity of msg in commit
franchuterivera committed May 12, 2021
1 parent 8447e91 commit ac2dd99
Showing 2 changed files with 346 additions and 3 deletions.
223 changes: 221 additions & 2 deletions autoPyTorch/evaluation/abstract_evaluator.py
@@ -262,13 +262,81 @@ def send_warnings_to_log(message, category, filename, lineno,


class AbstractEvaluator(object):
"""
This class defines the interface that pipeline evaluators should follow when
interacting with SMAC through ExecuteTaFuncWithQueue.
An evaluator is an object that:
+ constructs a pipeline (i.e. a classification or regression estimator) for a given
pipeline_config and run settings (budget, seed)
+ fits and trains this pipeline (TrainEvaluator) or tests a given
configuration (TestEvaluator)
The provided configuration determines the type of pipeline created. For more
details, please read the get_pipeline() method.
Attributes:
backend (Backend):
An object that allows interaction with the disk storage. In particular, it
allows access to the train and test datasets
queue (Queue):
Each worker available will instantiate an evaluator, and after completion,
it will append the result to a multiprocessing queue
metric (autoPyTorchMetric):
A scorer object used to evaluate how well a pipeline was fit. It is a
wrapper on top of the actual score method (for example, scikit-learn's
accuracy) that formats the predictions accordingly.
budget (float):
The number of epochs or the amount of time a configuration is allowed to run.
budget_type (str):
The budget type. Currently, only epoch and time are allowed.
pipeline_config (Optional[Dict[str, Any]]):
Defines the content of the pipeline being evaluated. For example, it
contains pipeline-specific settings like the logging name, or whether or not
to use tensorboard.
configuration (Union[int, str, Configuration]):
Determines the pipeline to be constructed. A dummy estimator is created for
integer configurations, a traditional machine learning pipeline is created
for string-based configurations, and NAS is performed when a configuration
object is passed.
seed (int):
An integer that allows for reproducibility of results
output_y_hat_optimization (bool):
Whether this worker should output the target predictions, so that they are
stored on disk. Fundamentally, the resampling strategy might shuffle the
Y_train targets, so we store the split in order to re-use them for ensemble
selection.
num_run (Optional[int]):
An identifier of the current configuration being fit. This number is unique per
configuration.
include (Optional[Dict[str, Any]]):
An optional dictionary to include components of the pipeline steps.
exclude (Optional[Dict[str, Any]]):
An optional dictionary to exclude components of the pipeline steps.
disable_file_output (Union[bool, List[str]]):
By default, the model, its predictions, and other metadata are stored on disk
for each finished configuration. This argument allows the user to prevent
certain file types, for example the model, from being written to disk.
init_params (Optional[Dict[str, Any]]):
Optional argument that is passed to each pipeline step. It is the equivalent of
kwargs for the pipeline steps.
logger_port (Optional[int]):
Logging is performed using a socket-server scheme to be robust against many
parallel entities that want to write to the same file. This integer states the
socket port for the communication channel.
If None is provided, the logging.handlers.DEFAULT_TCP_LOGGING_PORT is used.
all_supported_metrics (bool):
Whether all supported metrics should be calculated for every configuration.
search_space_updates (Optional[HyperparameterSearchSpaceUpdates]):
An object used to fine-tune the hyperparameter search space of the pipeline
"""
def __init__(self, backend: Backend,
queue: Queue,
metric: autoPyTorchMetric,
budget: float,
configuration: Union[int, str, Configuration],
budget_type: Optional[str] = None,
pipeline_config: Optional[Dict[str, Any]] = None,
seed: int = 1,
output_y_hat_optimization: bool = True,
num_run: Optional[int] = None,
@@ -408,6 +476,23 @@ def __init__(self, backend: Backend,
self.logger.debug("Search space updates :{}".format(self.search_space_updates))

def _get_pipeline(self) -> BaseEstimator:
"""
Constructs a pipeline object based on the type of the self.configuration attribute:
int: A dummy classifier/dummy regressor is created. This estimator serves
as a baseline model to ignore all models that perform worse than this
fixed estimator. Also, in the worst case scenario, this is the final
estimator created (for instance, in case not enough memory was allocated).
str: A pipeline with traditional classifiers like random forest, SVM, etc. is
created, as the configuration contains an estimator name that defines the
model to use, for example 'RandomForest'
Configuration: A pipeline object matching this configuration is created. This
is the case of neural architecture search, where different backbones
and heads can be passed in the form of a configuration object.
Returns:
pipeline (BaseEstimator):
A scikit-learn compliant pipeline which is not yet fit to the data.
"""
assert self.pipeline_class is not None, "Can't return pipeline, pipeline_class not initialised"
if isinstance(self.configuration, int):
pipeline = self.pipeline_class(config=self.configuration,
@@ -436,6 +521,15 @@ def _loss(self, y_true: np.ndarray, y_hat: np.ndarray) -> Dict[str, float]:
The calculate_loss function internally translates a score function into
a minimization problem
Args:
y_true (np.ndarray):
The expected labels given by the original dataset
y_hat (np.ndarray):
The prediction of the current pipeline being fit
Returns:
(Dict[str, float]):
A dictionary with metric_name -> metric_loss, for every
supported metric
"""

if isinstance(self.configuration, int):
@@ -461,7 +555,39 @@ def finish_up(self, loss: Dict[str, float], train_loss: Dict[str, float],
* saving the files for the ensembles_statistics
* generate output for SMAC
We use it as the signal handler so we can reuse the code for the
normal use case and when the runsolver kills us here :)
Args:
loss (Dict[str, float]):
The optimization loss, calculated on the validation set. This will
be the cost used in SMAC
train_loss (Dict[str, float]):
The train loss, calculated on the train set
opt_pred (np.ndarray):
The predictions on the validation set. This validation set is created
from the resampling strategy
valid_pred (Optional[np.ndarray]):
Predictions on a user provided validation set
test_pred (Optional[np.ndarray]):
Predictions on a user provided test set
additional_run_info (Optional[Dict]):
A dictionary with additional run information, like duration or
the crash error msg, if any.
file_output (bool):
Whether or not this pipeline should output information to disk
status (StatusType):
The status of the run, following SMAC StatusType syntax.
Returns:
duration (float):
The elapsed time of the training of this evaluator
loss (float):
The optimization loss of this run
seed (int):
The seed used while fitting the pipeline
additional_info (Dict):
Additional run information, like train/test loss
"""

self.duration = time.time() - self.starttime

@@ -508,6 +634,25 @@ def calculate_auxiliary_losses(
Y_valid_pred: np.ndarray,
Y_test_pred: np.ndarray,
) -> Tuple[Optional[float], Optional[float]]:
"""
A helper function to calculate the performance estimate of the
current pipeline on the user-provided validation/test set.
Args:
Y_valid_pred (np.ndarray):
predictions on a validation set provided by the user,
matching self.y_valid
Y_test_pred (np.ndarray):
predictions on a test set provided by the user,
matching self.y_test
Returns:
validation_loss (Optional[float]):
The validation loss under the optimization metric
stored in self.metric
test_loss (Optional[float]):
The test loss under the optimization metric
stored in self.metric
"""

validation_loss: Optional[float] = None

@@ -530,6 +675,31 @@ def file_output(
Y_valid_pred: np.ndarray,
Y_test_pred: np.ndarray
) -> Tuple[Optional[float], Dict]:
"""
This method decides what file outputs are written to disk.
It is also the interface to the backend's save_numrun_to_dir,
which stores all the pipeline-related information in a single
directory for easy identification of the current run.
Args:
Y_optimization_pred (np.ndarray):
The pipeline predictions on the validation set internally created
from self.y_train
Y_valid_pred (np.ndarray):
The pipeline predictions on the user provided validation set,
which should match self.y_valid
Y_test_pred (np.ndarray):
The pipeline predictions on the user provided test set,
which should match self.y_test
Returns:
loss (Optional[float]):
A loss in case the run failed to store files to
disk
error_dict (Dict):
A dictionary with an error that explains why a run
was not successfully stored to disk.
"""
# Abort if self.Y_optimization is None
# self.Y_optimization can be None if we use partial-cv, then,
# obviously no output should be saved.
@@ -624,6 +794,23 @@ def file_output(

def _predict_proba(self, X: np.ndarray, pipeline: BaseEstimator,
Y_train: Optional[np.ndarray] = None) -> np.ndarray:
"""
A wrapper function to handle the prediction of classification tasks.
It also makes sure that the predictions have the same dimensionality
as the expected labels
Args:
X (np.ndarray):
A set of features to feed to the pipeline
pipeline (BaseEstimator):
A model that takes the features X and returns a prediction y.
This pipeline must be a classification estimator that supports
the predict_proba method.
Y_train (Optional[np.ndarray]):
The expected labels, used to make sure the predictions match the
dimensionality of the targets (see _ensure_prediction_array_sizes)
Returns:
(np.ndarray):
The predictions of pipeline for the given features X
"""
@no_type_check
def send_warnings_to_log(message, category, filename, lineno,
file=None, line=None):
@@ -640,6 +827,24 @@ def send_warnings_to_log(message, category, filename, lineno,

def _predict_regression(self, X: np.ndarray, pipeline: BaseEstimator,
Y_train: Optional[np.ndarray] = None) -> np.ndarray:
"""
A wrapper function to handle the prediction of regression tasks.
It is a wrapper that provides the same interface as _predict_proba.
Regression predictions are expected to be unraveled. To comply with the
scikit-learn VotingRegressor requirement, if the estimator predicts an
(N,)-shaped array, it is converted to (N, 1)
Args:
X (np.ndarray):
A set of features to feed to the pipeline
pipeline (BaseEstimator):
A model that takes the features X and returns a prediction y
Y_train (Optional[np.ndarray]):
The training labels; kept only so this method provides the same
interface as _predict_proba.
Returns:
(np.ndarray):
The predictions of pipeline for the given features X
"""
@no_type_check
def send_warnings_to_log(message, category, filename, lineno,
file=None, line=None):
@@ -658,6 +863,20 @@ def send_warnings_to_log(message, category, filename, lineno,

def _ensure_prediction_array_sizes(self, prediction: np.ndarray,
Y_train: np.ndarray) -> np.ndarray:
"""
This method formats a prediction to match the dimensionality of the provided
labels (Y_train). This should be used exclusively for classification tasks
Args:
prediction (np.ndarray):
The un-formatted predictions of a pipeline
Y_train (np.ndarray):
The labels from the dataset, used to infer the expected
dimensionality of the predictions
Returns:
(np.ndarray):
The formatted prediction
"""
assert self.datamanager.num_classes is not None, "Called function on wrong task"
num_classes: int = self.datamanager.num_classes

