-
Notifications
You must be signed in to change notification settings - Fork 89
/
estimator_checks.py
163 lines (140 loc) · 6.29 KB
/
estimator_checks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
"""Estimator checker for extension."""
__author__ = ["fkiraly"]
__all__ = ["check_estimator"]
from inspect import isclass
def check_estimator(
    estimator,
    raise_exceptions=False,
    tests_to_run=None,
    fixtures_to_run=None,
    verbose=True,
    tests_to_exclude=None,
    fixtures_to_exclude=None,
):
    """Run all tests on one single estimator.

    Tests that are run on estimator:

    * all tests in test_all_estimators
    * all interface compatibility tests from the module of estimator's type,
      for example, test_all_forecasters if estimator is a forecaster

    Parameters
    ----------
    estimator : estimator class or estimator instance
    raise_exceptions : bool, optional, default=False
        whether to return exceptions/failures in the results dict, or raise them

        * if False: returns exceptions in returned `results` dict
        * if True: raises exceptions as they occur

    tests_to_run : str or list of str, optional. Default = run all tests.
        Names (test/function name string) of tests to run.
        sub-sets tests that are run to the tests given here.
    fixtures_to_run : str or list of str, optional. Default = run all tests.
        pytest test-fixture combination codes, which test-fixture combinations to run.
        sub-sets tests and fixtures to run to the list given here.
        If both tests_to_run and fixtures_to_run are provided, runs the *union*,
        i.e., all test-fixture combinations for tests in tests_to_run,
        plus all test-fixture combinations in fixtures_to_run.
    verbose : bool, optional, default=True.
        whether to print out informative summary of tests run.
    tests_to_exclude : str or list of str, names of tests to exclude. default = None
        removes tests that should not be run, after subsetting via tests_to_run.
    fixtures_to_exclude : str or list of str, fixtures to exclude. default = None
        removes test-fixture combinations that should not be run.
        This is done after subsetting via fixtures_to_run.

    Returns
    -------
    results : dict of results of the tests in self
        keys are test/fixture strings, identical as in pytest, e.g., test[fixture]
        entries are the string "PASSED" if the test passed,
        or the exception raised if the test did not pass.
        Exceptions appear as entries only if ``raise_exceptions=False``;
        otherwise they are raised directly.

    Raises
    ------
    if ``raise_exceptions=True``,
    raises any exception produced by the tests directly

    Examples
    --------
    >>> from aeon.transformations.exponent import ExponentTransformer
    >>> from aeon.utils.estimator_checks import check_estimator

    Running all tests for ExponentTransformer class,
    this uses all instances from get_test_params and compatible scenarios

    >>> results = check_estimator(ExponentTransformer)
    All tests PASSED!

    Running all tests for a specific ExponentTransformer
    this uses the instance that is passed and compatible scenarios

    >>> results = check_estimator(ExponentTransformer(42))
    All tests PASSED!

    Running specific test (all fixtures) for ExponentTransformer

    >>> results = check_estimator(ExponentTransformer, tests_to_run="test_clone")
    All tests PASSED!
    {'test_clone[ExponentTransformer-0]': 'PASSED',
    'test_clone[ExponentTransformer-1]': 'PASSED'}

    Running one specific test-fixture-combination for ExponentTransformer

    >>> check_estimator(
    ...    ExponentTransformer, fixtures_to_run="test_clone[ExponentTransformer-1]"
    ... )
    All tests PASSED!
    {'test_clone[ExponentTransformer-1]': 'PASSED'}
    """
    # imports are function-local to avoid circular imports and to keep
    # heavy test-module imports out of the package import path
    from aeon.base import BaseEstimator
    from aeon.classification.early_classification.tests.test_all_early_classifiers import (  # noqa E501
        TestAllEarlyClassifiers,
    )
    from aeon.classification.tests.test_all_classifiers import TestAllClassifiers
    from aeon.forecasting.tests.test_all_forecasters import TestAllForecasters
    from aeon.registry import get_identifiers
    from aeon.regression.tests.test_all_regressors import TestAllRegressors
    from aeon.tests.test_all_estimators import TestAllEstimators, TestAllObjects
    from aeon.transformations.tests.test_all_transformers import TestAllTransformers

    # maps estimator identifier (as returned by get_identifiers) to the
    # interface-compatibility test class for that estimator type
    testclass_dict = {
        "classifier": TestAllClassifiers,
        "early_classifier": TestAllEarlyClassifiers,
        "forecaster": TestAllForecasters,
        "regressor": TestAllRegressors,
        "transformer": TestAllTransformers,
    }

    # tests that apply to all objects are always run
    results = TestAllObjects().run_tests(
        estimator=estimator,
        raise_exceptions=raise_exceptions,
        tests_to_run=tests_to_run,
        fixtures_to_run=fixtures_to_run,
        tests_to_exclude=tests_to_exclude,
        fixtures_to_exclude=fixtures_to_exclude,
    )

    def is_estimator(obj):
        """Return whether obj is an estimator class or estimator object."""
        if isclass(obj):
            return issubclass(obj, BaseEstimator)
        return isinstance(obj, BaseEstimator)

    # estimator-level tests run only for BaseEstimator classes/instances,
    # not for plain BaseObjects
    if is_estimator(estimator):
        results_estimator = TestAllEstimators().run_tests(
            estimator=estimator,
            raise_exceptions=raise_exceptions,
            tests_to_run=tests_to_run,
            fixtures_to_run=fixtures_to_run,
            tests_to_exclude=tests_to_exclude,
            fixtures_to_exclude=fixtures_to_exclude,
        )
        results.update(results_estimator)

    # best-effort lookup of the estimator type; if the registry cannot
    # identify it, fall back to "" so no type-specific tests are run
    try:
        scitype_of_estimator = get_identifiers(estimator)
    except Exception:
        scitype_of_estimator = ""

    # run the type-specific interface tests, if the type is known
    if scitype_of_estimator in testclass_dict:
        results_scitype = testclass_dict[scitype_of_estimator]().run_tests(
            estimator=estimator,
            raise_exceptions=raise_exceptions,
            tests_to_run=tests_to_run,
            fixtures_to_run=fixtures_to_run,
            tests_to_exclude=tests_to_exclude,
            fixtures_to_exclude=fixtures_to_exclude,
        )
        results.update(results_scitype)

    # summarise: list failed test-fixture keys, or report full success
    failed_tests = [key for key, value in results.items() if value != "PASSED"]
    if len(failed_tests) > 0:
        msg = "\n".join("FAILED: " + x for x in failed_tests)
    else:
        msg = "All tests PASSED!"

    if verbose:
        # printing is an intended feature, for console usage and interactive debugging
        print(msg)  # noqa T001

    return results