/
mixin.py
85 lines (72 loc) · 2.82 KB
/
mixin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
class TransformerMixin(object):
    """Mixin class for all transformers in scikit-learn.

    Supplies a generic ``fit_transform`` built from the subclass's own
    ``fit`` and ``transform`` methods.
    """

    def fit_transform(self, X, y=None, **fit_params):
        """Fit to data, then transform it.

        Calls ``fit`` on ``X`` (and ``y`` when given) with any extra
        ``fit_params``, then returns the result of calling ``transform(X)``
        on the object that ``fit`` returned.

        Parameters
        ----------
        X : array, shape (n_samples, n_features)
            Training set.
        y : array, shape (n_samples,) | None
            Target values or class labels. When ``None``, ``fit`` is called
            with ``X`` only; otherwise ``y`` is passed as its second
            positional argument.
        **fit_params : dict
            Additional fitting parameters passed to the ``fit`` method.

        Returns
        -------
        X_new : array, shape (n_samples, n_features_new)
            Transformed array.
        """
        # Generic, non-optimized fallback; subclasses can override this
        # when a combined fit+transform is cheaper for their algorithm.
        # Unsupervised transformers take only X, supervised ones take X, y.
        positional = (X,) if y is None else (X, y)
        fitted = self.fit(*positional, **fit_params)
        return fitted.transform(X)
class EstimatorMixin(object):
    """Mixin class for estimators.

    Provides a ``set_params`` method modeled on the scikit-learn API. It
    relies on ``get_params`` returning a mapping of parameter names to
    their current values. The ``get_params`` defined here is only a
    placeholder that returns ``None`` (presumably overridden by
    subclasses).
    """

    def get_params(self, deep=True):
        """Get the estimator params.

        This base implementation is a placeholder and returns ``None``.

        Parameters
        ----------
        deep : bool
            Whether to include parameters of nested objects. Unused by
            this placeholder.

        Returns
        -------
        params : None
            Always ``None`` here. ``set_params`` expects overriding
            implementations to return a dict mapping parameter names to
            values.
        """
        return

    def set_params(self, **params):
        """Set parameters (mimics sklearn API).

        Parameters
        ----------
        **params : dict
            Parameter names mapped to new values. A key of the form
            ``<component>__<parameter>`` is forwarded to the ``set_params``
            method of the object stored under ``<component>``. Simple
            parameters are assigned first, so a component replaced in the
            same call receives its nested parameters regardless of keyword
            order.

        Returns
        -------
        inst : object
            The instance.

        Raises
        ------
        ValueError
            If a parameter name (or the component part of a nested
            parameter) is not a key of the mapping returned by
            ``get_params``.
        """
        if not params:
            return self
        # Copy so the bookkeeping below never mutates whatever mapping
        # get_params hands back; treat the ``None`` placeholder as "no
        # valid parameters" so callers get a ValueError, not a TypeError.
        valid_params = dict(self.get_params(deep=True) or {})
        nested_params = {}
        for key, value in params.items():
            # 'a__b__c' -> name 'a', sub_name 'b__c' (split on first '__')
            name, delim, sub_name = key.partition('__')
            if name not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (name, self.__class__.__name__))
            if delim:
                # nested objects case: defer until every simple parameter
                # has been assigned (see the loop below).
                nested_params.setdefault(name, {})[sub_name] = value
            else:
                # simple objects case; refresh the snapshot so deferred
                # nested params reach the newly assigned object.
                setattr(self, key, value)
                valid_params[key] = value
        for name, sub_params in nested_params.items():
            valid_params[name].set_params(**sub_params)
        return self