Skip to content

Commit

Permalink
[ENH] in check_estimator and run_tests replace `return_exceptions…
Browse files Browse the repository at this point in the history
…` arg with `raise_exceptions`, with deprecation (sktime#4030)

This PR changes the `return_exceptions` argument in some testing utilities to `raise_exceptions`, with deprecation (for 0.17.0).

Affected functions are `check_estimator`, and `QuickTester.run_tests`.

The reason for the change is that an argument `raise_exceptions` is much clearer to the user in semantics and function than `return_exceptions`. Since, in common user perception, `raise_exceptions` is "what happens" in comparison to a baseline case of it not happening.

It also becomes consistent with arguments in the testing module, e.g., conditional fixture generation functionality such as `create_conditional_fixtures_and_names` where the argument is already called `raise_exceptions`.

Also changes any internal references to the argument to the post-deprecation state.
  • Loading branch information
fkiraly authored and klam-data committed Jan 18, 2023
1 parent 3470f9d commit 6e5751b
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 23 deletions.
4 changes: 2 additions & 2 deletions docs/source/developer_guide/add_estimators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ Example: ``'test_repr[NaiveForecaster-2]'``, where ``test_repr`` is the test nam

Values of the return ``dict`` are either the string ``"PASSED"``, if the test succeeds, or the exception that the test would raise at failure.
``check_estimator`` does not raise exceptions by default, the default is returning them as dictionary values.
To raise the exceptions instead, e.g., for debugging, use the argument ``return_exceptions=False``,
To raise the exceptions instead, e.g., for debugging, use the argument ``raise_exceptions=True``,
which will raise the exceptions instead of returning them as dictionary values.
In that case, there will be at most one exception raised, namely the first exception encountered in the test execution order.

Expand Down Expand Up @@ -176,7 +176,7 @@ A useful workflow for using ``check_estimator`` to debug an estimator is as foll

1. Run ``check_estimator(MyEstimator)`` to find failing tests
2. Subset to failing tests or fixtures using ``fixtures_to_run`` or ``tests_to_run``
3. If the failure is not obvious, set ``return_exceptions=False`` to raise the exception and inspecet the traceback.
3. If the failure is not obvious, set ``raise_exceptions=True`` to raise the exception and inspecet the traceback.
4. If the failure is still not clear, use advanced debuggers on the line of code with ``check_estimator``.

Running the test suite in a repository clone
Expand Down
2 changes: 1 addition & 1 deletion examples/01b_forecasting_proba.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@
"outputs": [],
"source": [
"# to raise errors for use in traceback debugging:\n",
"check_estimator(NaiveForecaster, return_exceptions=False)\n",
"check_estimator(NaiveForecaster, raise_exceptions=True)\n",
"# this does not raise an error since NaiveForecaster is fine, but would if it weren't"
]
},
Expand Down
2 changes: 1 addition & 1 deletion sktime/performance_metrics/tests/test_metrics_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,4 +129,4 @@ def custom_mape(y_true, y_pred) -> float:
score = fc_scorer(y, y)
assert isinstance(score, float)

check_estimator(fc_scorer, return_exceptions=False)
check_estimator(fc_scorer, raise_exceptions=True)
53 changes: 46 additions & 7 deletions sktime/tests/test_all_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from copy import deepcopy
from inspect import getfullargspec, isclass, signature
from tempfile import TemporaryDirectory
from warnings import warn

import joblib
import numpy as np
Expand Down Expand Up @@ -405,14 +406,26 @@ def _generate_method_nsc_arraylike(self, test_name, **kwargs):
class QuickTester:
"""Mixin class which adds the run_tests method to run tests on one estimator."""

# todo 0.17.0:
# * remove the return_exceptions arg
# * move the raise_exceptions arg to 2nd place
# * change its default to False, from None
# * update the docstring - remove return_exceptions
# * update the docstring - move raise_exceptions block to 2nd place
# * update the docstring - remove deprecation references
# * update the docstring - condition in return block, refer only to raise_exceptions
# * update the docstring - condition in raises block, refer only to raise_exceptions
# * remove the code block for input handling
# * remove import of warn
def run_tests(
self,
estimator,
return_exceptions=True,
return_exceptions=None,
tests_to_run=None,
fixtures_to_run=None,
tests_to_exclude=None,
fixtures_to_exclude=None,
raise_exceptions=None,
):
"""Run all tests on one single estimator.
Expand All @@ -430,8 +443,11 @@ def run_tests(
estimator : estimator class or estimator instance
return_exceptions : bool, optional, default=True
whether to return exceptions/failures, or raise them
if True: returns exceptions in results
if True: returns exceptions in returned `results` dict
if False: raises exceptions as they occur
deprecated in 0.15.1, and will be replaced by `raise_exceptions` in 0.17.0.
Overridden to `False` if `raise_exceptions=True`.
For safe deprecation, use `raise_exceptions` instead of `return_exceptions`.
tests_to_run : str or list of str, names of tests to run. default = all tests
sub-sets tests that are run to the tests given here.
fixtures_to_run : str or list of str, pytest test-fixture combination codes.
Expand All @@ -445,18 +461,27 @@ def run_tests(
fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
removes test-fixture combinations that should not be run.
This is done after subsetting via fixtures_to_run.
raise_exceptions : bool, optional, default=False
whether to return exceptions/failures in the results dict, or raise them
if False: returns exceptions in returned `results` dict
if True: raises exceptions as they occur
Overrides `return_exceptions` if used as a keyword argument.
both `raise_exceptions=True` and `return_exceptions=True`.
Will move to replace `return_exceptions` as 2nd arg in 0.17.0.
Returns
-------
results : dict of results of the tests in self
keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
entries are the string "PASSED" if the test passed,
or the exception raised if the test did not pass
returned only if all tests pass, or return_exceptions=True
returned only if all tests pass,
or both return_exceptions=True and raise_exceptions=False
Raises
------
if return_exception=False, raises any exception produced by the tests directly
if return_exceptions=False, or raise_exceptions=True,
raises any exception produced by the tests directly
Examples
--------
Expand All @@ -472,6 +497,22 @@ def run_tests(
... )
{'test_repr[NaiveForecaster-2]': 'PASSED'}
"""
# todo 0.17.0: remove this code block
if return_exceptions is None and raise_exceptions is None:
raise_exceptions = False

if return_exceptions is not None and raise_exceptions is None:
warn(
"The return_exceptions argument of check_estimator has been deprecated "
"since 0.15.1, and will be replaced by raise_exceptions in 0.17.0. "
"For safe deprecation: use raise_exceptions argument instead of "
"return_exceptions when using keywords. Avoid positional use, instead "
"ensure to use keywords. When not using keywords, the "
"default behaviour will not change."
)
raise_exceptions = not return_exceptions
# end block to remove

tests_to_run = self._check_None_str_or_list_of_str(
tests_to_run, var_name="tests_to_run"
)
Expand Down Expand Up @@ -546,8 +587,6 @@ def _generate_estimator_instance_cls(test_name, **kwargs):
fixture_vars = getfullargspec(test_fun)[0][1:]
fixture_vars = [var for var in fixture_sequence if var in fixture_vars]

raise_exceptions = not return_exceptions

# this call retrieves the conditional fixtures
# for the test test_name, and the estimator
_, fixture_prod, fixture_names = create_conditional_fixtures_and_names(
Expand Down Expand Up @@ -595,7 +634,7 @@ def _generate_estimator_instance_cls(test_name, **kwargs):
if fixtures_to_exclude is not None and key in fixtures_to_exclude:
continue

if return_exceptions:
if not raise_exceptions:
try:
test_fun(**deepcopy(args))
results[key] = "PASSED"
Expand Down
61 changes: 51 additions & 10 deletions sktime/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,29 @@
__all__ = ["check_estimator"]

from inspect import isclass


from warnings import warn


# todo 0.17.0:
# * remove the return_exceptions arg
# * move the raise_exceptions arg to 2nd place
# * change its default to False, from None
# * update the docstring - remove return_exceptions
# * update the docstring - move raise_exceptions block to 2nd place
# * update the docstring - remove deprecation references
# * update the docstring - condition in return block, refer only to raise_exceptions
# * update the docstring - condition in raises block to refer only to raise_exceptions
# * remove the code block for input handling
# * remove import of warn
def check_estimator(
estimator,
return_exceptions=True,
return_exceptions=None,
tests_to_run=None,
fixtures_to_run=None,
verbose=True,
tests_to_exclude=None,
fixtures_to_exclude=None,
raise_exceptions=None,
):
"""Run all tests on one single estimator.
Expand All @@ -26,10 +39,13 @@ def check_estimator(
Parameters
----------
estimator : estimator class or estimator instance
return_exception : bool, optional, default=True
return_exceptions : bool, optional, default=True
whether to return exceptions/failures, or raise them
if True: returns exceptions in results
if True: returns exceptions in returned `results` dict
if False: raises exceptions as they occur
deprecated since 0.15.1, and will be replaced by `raise_exceptions` in 0.17.0.
Overridden to `False` if `raise_exceptions=True`.
For safe deprecation, use `raise_exceptions` instead of `return_exceptions`.
tests_to_run : str or list of str, optional. Default = run all tests.
Names (test/function name string) of tests to run.
sub-sets tests that are run to the tests given here.
Expand All @@ -46,18 +62,27 @@ def check_estimator(
fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
removes test-fixture combinations that should not be run.
This is done after subsetting via fixtures_to_run.
raise_exceptions : bool, optional, default=False
whether to return exceptions/failures in the results dict, or raise them
if False: returns exceptions in returned `results` dict
if True: raises exceptions as they occur
Overrides `return_exceptions` if used as a keyword argument.
both `raise_exceptions=True` and `return_exceptions=True`.
Will move to replace `return_exceptions` as 2nd arg in 0.17.0.
Returns
-------
results : dict of results of the tests in self
keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
entries are the string "PASSED" if the test passed,
or the exception raised if the test did not pass
returned only if all tests pass, or return_exceptions=True
returned only if all tests pass,
or both return_exceptions=True and raise_exceptions=False
Raises
------
if return_exception=False, raises any exception produced by the tests directly
if return_exceptions=False, or raise_exceptions=True,
raises any exception produced by the tests directly
Examples
--------
Expand Down Expand Up @@ -103,6 +128,22 @@ def check_estimator(
from sktime.tests.test_all_estimators import TestAllEstimators, TestAllObjects
from sktime.transformations.tests.test_all_transformers import TestAllTransformers

# todo 0.17.0: remove this code block
if return_exceptions is None and raise_exceptions is None:
raise_exceptions = False

if return_exceptions is not None and raise_exceptions is None:
warn(
"The return_exceptions argument of check_estimator has been deprecated "
"since 0.15.1, and will be replaced by raise_exceptions in 0.17.0. "
"For safe deprecation: use raise_exceptions argument instead of "
"return_exceptions when using keywords. Avoid positional use, instead "
"ensure to use keywords. When not using keywords, the "
"default behaviour will not change."
)
raise_exceptions = not return_exceptions
# end block to remove

testclass_dict = dict()
testclass_dict["classifier"] = TestAllClassifiers
testclass_dict["early_classifier"] = TestAllEarlyClassifiers
Expand All @@ -114,7 +155,7 @@ def check_estimator(

results = TestAllObjects().run_tests(
estimator=estimator,
return_exceptions=return_exceptions,
raise_exceptions=raise_exceptions,
tests_to_run=tests_to_run,
fixtures_to_run=fixtures_to_run,
tests_to_exclude=tests_to_exclude,
Expand All @@ -131,7 +172,7 @@ def is_estimator(obj):
if is_estimator(estimator):
results_estimator = TestAllEstimators().run_tests(
estimator=estimator,
return_exceptions=return_exceptions,
raise_exceptions=raise_exceptions,
tests_to_run=tests_to_run,
fixtures_to_run=fixtures_to_run,
tests_to_exclude=tests_to_exclude,
Expand All @@ -147,7 +188,7 @@ def is_estimator(obj):
if scitype_of_estimator in testclass_dict.keys():
results_scitype = testclass_dict[scitype_of_estimator]().run_tests(
estimator=estimator,
return_exceptions=return_exceptions,
raise_exceptions=raise_exceptions,
tests_to_run=tests_to_run,
fixtures_to_run=fixtures_to_run,
tests_to_exclude=tests_to_exclude,
Expand Down
4 changes: 2 additions & 2 deletions sktime/utils/tests/test_check_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ def test_check_estimator_does_not_raise(estimator_class):
"""Test that check_estimator does not raise exceptions on examples we know pass."""
estimator_instance = estimator_class.create_test_instance()

check_estimator(estimator_class, return_exceptions=False, verbose=False)
check_estimator(estimator_class, raise_exceptions=True, verbose=False)

check_estimator(estimator_instance, return_exceptions=False, verbose=False)
check_estimator(estimator_instance, raise_exceptions=True, verbose=False)


def test_check_estimator_subset_tests():
Expand Down

0 comments on commit 6e5751b

Please sign in to comment.