forked from scikit-learn/scikit-learn
/
test_estimator_checks.py
156 lines (126 loc) · 5.35 KB
/
test_estimator_checks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import scipy.sparse as sp
import numpy as np
import sys
from sklearn.externals.six.moves import cStringIO as StringIO
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils.testing import assert_raises_regex, assert_true
from sklearn.utils.estimator_checks import check_estimator
from sklearn.utils.estimator_checks import check_estimators_unfitted
from sklearn.ensemble import AdaBoostClassifier
from sklearn.linear_model import MultiTaskElasticNet
from sklearn.utils.validation import check_X_y, check_array
class CorrectNotFittedError(ValueError):
    """Test-only error signalling that an estimator was used before fit.

    The class derives from ValueError but deliberately not from
    AttributeError, so it exercises the case where only one of the two
    accepted base exceptions is present.
    """
class BaseBadClassifier(BaseEstimator, ClassifierMixin):
    """Minimal classifier that performs no input validation at all.

    Other intentionally flawed classifiers in this module subclass it and
    override only the methods they need.
    """

    def fit(self, X, y):
        """Ignore the training data and return the estimator unchanged."""
        return self

    def predict(self, X):
        """Return an array of ones with one entry per row of ``X``."""
        n_samples = X.shape[0]
        return np.ones(n_samples)
class ChangesDict(BaseEstimator):
    """Estimator that wrongly modifies its own attributes in ``predict``."""

    def __init__(self):
        self.key = 0

    def fit(self, X, y=None):
        """Validate ``X`` and ``y`` with ``check_X_y``; nothing is learned."""
        # Only the validation side effect (raising on bad input) matters;
        # the converted arrays are not needed afterwards.
        check_X_y(X, y)
        return self

    def predict(self, X):
        """Validate ``X``, overwrite ``self.key`` and return all ones."""
        validated = check_array(X)
        # Mutating state at predict time is the flaw this class exhibits.
        self.key = 1000
        n_samples = validated.shape[0]
        return np.ones(n_samples)
class NoCheckinPredict(BaseBadClassifier):
    """Classifier that validates its input in ``fit`` only.

    ``predict`` is inherited from BaseBadClassifier and therefore performs
    no validation.
    """

    def fit(self, X, y):
        """Run ``check_X_y`` on the training data and return the estimator."""
        # The validated arrays are discarded; only the checks are wanted.
        check_X_y(X, y)
        return self
class NoSparseClassifier(BaseBadClassifier):
    """Classifier that fails on sparse input with an uninformative error."""

    def fit(self, X, y):
        """Validate the data, then raise ValueError if ``X`` is sparse."""
        X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
        if not sp.issparse(X):
            return self
        # Sparse matrices pass validation above but are rejected here with
        # a message that says nothing about sparsity.
        raise ValueError("Nonsensical Error")

    def predict(self, X):
        """Validate ``X`` and return an array of ones, one per row."""
        validated = check_array(X)
        return np.ones(validated.shape[0])
class CorrectNotFittedErrorClassifier(BaseBadClassifier):
    """Classifier that raises CorrectNotFittedError when used unfitted."""

    def fit(self, X, y):
        """Validate the data and store a ``coef_`` vector of ones."""
        X, y = check_X_y(X, y)
        n_features = X.shape[1]
        self.coef_ = np.ones(n_features)
        return self

    def predict(self, X):
        """Return ones for each row of ``X``; fail if ``fit`` was not run."""
        # ``coef_`` only exists once ``fit`` has been called.
        is_fitted = hasattr(self, 'coef_')
        if not is_fitted:
            raise CorrectNotFittedError("estimator is not fitted yet")
        validated = check_array(X)
        return np.ones(validated.shape[0])
class NoSampleWeightPandasSeriesType(BaseEstimator):
    """Estimator whose ``fit`` rejects a pandas.Series ``sample_weight``."""

    def fit(self, X, y, sample_weight=None):
        """Validate the data, then refuse Series-typed sample weights.

        Parameters
        ----------
        X, y : passed to ``check_X_y`` (sparse csr/csc accepted,
            multi-output allowed, ``y`` must be numeric).
        sample_weight : any object, optional
            Only its type is inspected.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``sample_weight`` is a ``pandas.Series``.
        """
        # Convert data
        X, y = check_X_y(X, y,
                         accept_sparse=("csr", "csc"),
                         multi_output=True,
                         y_numeric=True)
        # Function is only called after we verify that pandas is installed
        from pandas import Series
        if isinstance(sample_weight, Series):
            # Adjacent string literals are concatenated verbatim, so the
            # separating space has to be part of the first literal.
            raise ValueError("Estimator does not accept 'sample_weight' "
                             "of type pandas.Series")
        return self

    def predict(self, X):
        """Validate ``X`` and return an array of ones, one per row."""
        X = check_array(X)
        return np.ones(X.shape[0])
def test_check_estimator():
    """Check that check_estimator rejects flawed estimators.

    Each deliberately broken class defined in this module is expected to
    trigger one specific failure. This is not a complete test of all
    checks, which are very extensive. Two real estimators are run at the
    end to confirm that valid estimators pass.
    """
    # check that we have a set_params and can clone
    msg = "it does not implement a 'get_params' methods"
    assert_raises_regex(TypeError, msg, check_estimator, object)

    # check that we have a fit method
    msg = "object has no attribute 'fit'"
    assert_raises_regex(AttributeError, msg, check_estimator, BaseEstimator)

    # check that fit does input validation
    msg = "TypeError not raised"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        BaseBadClassifier)

    # check that sample_weights in fit accepts pandas.Series type
    try:
        from pandas import Series  # noqa
    except ImportError:
        # pandas is optional: skip this particular check when it is missing
        pass
    else:
        # Kept outside the ``try`` so that an ImportError raised while the
        # check itself runs is reported instead of being mistaken for
        # "pandas is not installed".
        msg = ("Estimator NoSampleWeightPandasSeriesType raises error if "
               "'sample_weight' parameter is of type pandas.Series")
        assert_raises_regex(
            ValueError, msg, check_estimator, NoSampleWeightPandasSeriesType)

    # check that predict does input validation (rejects NaN and inf)
    msg = "Estimator doesn't check for NaN and inf in predict"
    assert_raises_regex(AssertionError, msg, check_estimator,
                        NoCheckinPredict)

    # check that estimator state does not change
    # at transform/predict/predict_proba time
    msg = 'Estimator changes __dict__ during predict'
    assert_raises_regex(AssertionError, msg, check_estimator, ChangesDict)

    # check for sparse matrix input handling
    name = NoSparseClassifier.__name__
    msg = ("Estimator " + name +
           " doesn't seem to fail gracefully on sparse data")
    # the check for sparse input handling prints to the stdout,
    # instead of raising an error, so as not to remove the original traceback.
    # that means we need to jump through some hoops to catch it.
    old_stdout = sys.stdout
    string_buffer = StringIO()
    sys.stdout = string_buffer
    try:
        check_estimator(NoSparseClassifier)
    except Exception:
        # Any failure of the check is expected and irrelevant here: only
        # the printed message matters.  KeyboardInterrupt and SystemExit
        # are intentionally not swallowed.
        pass
    finally:
        sys.stdout = old_stdout
    assert_true(msg in string_buffer.getvalue())

    # doesn't error on actual estimator
    check_estimator(AdaBoostClassifier)
    check_estimator(MultiTaskElasticNet)
def test_check_estimators_unfitted():
    """Exercise check_estimators_unfitted on a failing and a passing class."""
    # NoSparseClassifier.predict never complains about being unfitted, so
    # the check is expected to fail with an AssertionError.
    expected_message = "AttributeError or ValueError not raised by predict"
    assert_raises_regex(
        AssertionError, expected_message, check_estimators_unfitted,
        "estimator", NoSparseClassifier)

    # CorrectNotFittedError inherits from ValueError only; the check must
    # accept it and therefore return without raising.
    check_estimators_unfitted("estimator", CorrectNotFittedErrorClassifier)