Skip to content

Commit

Permalink
Enable Mypy in evaluation (except Train Evaluator) (#1077)
Browse files Browse the repository at this point in the history
* Almost all files for evaluation

* Feedback from PR

* Feedback from comments

* Solving rebase artifacts

* Revert bytes
  • Loading branch information
franchuterivera committed Feb 17, 2021
1 parent 13f9c1f commit cda9876
Show file tree
Hide file tree
Showing 7 changed files with 556 additions and 328 deletions.
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ repos:
args: [--show-error-codes]
name: mypy auto-sklearn-util
files: autosklearn/util
- id: mypy
args: [--show-error-codes]
name: mypy auto-sklearn-evaluation
files: autosklearn/evaluation
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
hooks:
Expand Down
68 changes: 50 additions & 18 deletions autosklearn/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from queue import Empty
import time
import traceback
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from ConfigSpace import Configuration
import numpy as np
import pynisher
from smac.runhistory.runhistory import RunInfo, RunValue
from smac.stats.stats import Stats
from smac.tae import StatusType, TAEAbortException
from smac.tae.execute_func import AbstractTAFunc

Expand All @@ -23,11 +24,17 @@
import autosklearn.evaluation.train_evaluator
import autosklearn.evaluation.test_evaluator
import autosklearn.evaluation.util
from autosklearn.util.logging_ import get_named_client_logger
from autosklearn.evaluation.train_evaluator import TYPE_ADDITIONAL_INFO
from autosklearn.util.backend import Backend
from autosklearn.util.logging_ import PickableLoggerAdapter, get_named_client_logger
from autosklearn.util.parallel import preload_modules


def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
def fit_predict_try_except_decorator(
ta: Callable,
queue: multiprocessing.Queue,
cost_for_crash: float,
**kwargs: Any) -> None:

try:
return ta(queue=queue, **kwargs)
Expand Down Expand Up @@ -66,7 +73,7 @@ def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
queue.close()


def get_cost_of_crash(metric):
def get_cost_of_crash(metric: Scorer) -> float:

# The metric must always be defined to extract optimum/worst
if not isinstance(metric, Scorer):
Expand All @@ -85,8 +92,11 @@ def get_cost_of_crash(metric):
return worst_possible_result


def _encode_exit_status(exit_status):
def _encode_exit_status(exit_status: Union[str, int, Type[BaseException]]
) -> Union[str, int]:
try:
# If it can be dumped, then it is int
exit_status = cast(int, exit_status)
json.dumps(exit_status)
return exit_status
except (TypeError, OverflowError):
Expand All @@ -97,13 +107,31 @@ def _encode_exit_status(exit_status):
# easier debugging of potential crashes
class ExecuteTaFuncWithQueue(AbstractTAFunc):

def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
cost_for_crash, abort_on_first_run_crash, port, pynisher_context,
initial_num_run=1, stats=None,
run_obj='quality', par_factor=1, scoring_functions=None,
output_y_hat_optimization=True, include=None, exclude=None,
memory_limit=None, disable_file_output=False, init_params=None,
budget_type=None, ta=False, **resampling_strategy_args):
def __init__(
self,
backend: Backend,
autosklearn_seed: int,
resampling_strategy: Union[str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit],
metric: Scorer,
cost_for_crash: float,
abort_on_first_run_crash: bool,
port: int,
pynisher_context: str,
initial_num_run: int = 1,
stats: Optional[Stats] = None,
run_obj: str = 'quality',
par_factor: int = 1,
scoring_functions: Optional[List[Scorer]] = None,
output_y_hat_optimization: bool = True,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
memory_limit: Optional[int] = None,
disable_file_output: bool = False,
init_params: Optional[Dict[str, Any]] = None,
budget_type: Optional[str] = None,
ta: Optional[Callable] = None,
**resampling_strategy_args: Any,
):

if resampling_strategy == 'holdout':
eval_function = autosklearn.evaluation.train_evaluator.eval_holdout
Expand Down Expand Up @@ -180,7 +208,7 @@ def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
self.port = port
self.pynisher_context = pynisher_context
if self.port is None:
self.logger = logging.getLogger("TAE")
self.logger: Union[logging.Logger, PickableLoggerAdapter] = logging.getLogger("TAE")
else:
self.logger = get_named_client_logger(
name="TAE",
Expand Down Expand Up @@ -261,6 +289,10 @@ def run(
instance_specific: Optional[str] = None,
) -> Tuple[StatusType, float, float, Dict[str, Union[int, float, str, Dict, List, Tuple]]]:

# Additional information of each of the tae executions
# Defined upfront for mypy
additional_run_info: TYPE_ADDITIONAL_INFO = {}

context = multiprocessing.get_context(self.pynisher_context)
preload_modules(context)
queue = context.Queue()
Expand All @@ -272,7 +304,7 @@ def run(
init_params.update(self.init_params)

if self.port is None:
logger = logging.getLogger("pynisher")
logger: Union[logging.Logger, PickableLoggerAdapter] = logging.getLogger("pynisher")
else:
logger = get_named_client_logger(
name="pynisher",
Expand Down Expand Up @@ -320,11 +352,11 @@ def run(
except Exception as e:
exception_traceback = traceback.format_exc()
error_message = repr(e)
additional_info = {
additional_run_info.update({
'traceback': exception_traceback,
'error': error_message
}
return StatusType.CRASHED, self.cost_for_crash, 0.0, additional_info
})
return StatusType.CRASHED, self.worst_possible_result, 0.0, additional_run_info

if obj.exit_status in (pynisher.TimeoutException, pynisher.MemorylimitException):
# Even if the pynisher thinks that a timeout or memout occured,
Expand Down Expand Up @@ -359,7 +391,7 @@ def run(
elif obj.exit_status is pynisher.MemorylimitException:
status = StatusType.MEMOUT
additional_run_info = {
'error': 'Memout (used more than %d MB).' % self.memory_limit
"error": "Memout (used more than {} MB).".format(self.memory_limit)
}
else:
raise ValueError(obj.exit_status)
Expand Down
Loading

0 comments on commit cda9876

Please sign in to comment.