/
transform.py
67 lines (52 loc) · 2.06 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""transform_feature_names implementations for scikit-learn transformers
"""
import numpy as np
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.feature_selection.base import SelectorMixin
from sklearn.preprocessing import (
MinMaxScaler,
StandardScaler,
MaxAbsScaler,
RobustScaler,
)
from eli5.transform import transform_feature_names
from eli5.sklearn.utils import get_feature_names as _get_feature_names
# Feature selection:
@transform_feature_names.register(SelectorMixin)
def _select_names(est, in_names=None):
    """Return the input feature names kept by the selector ``est``.

    The selector's support mask is fetched with
    ``est.get_support(indices=False)``.  Input names are resolved through
    ``_get_feature_names`` (sized to the mask length), and the names at the
    positions where the mask is non-zero are returned as a list, in order.
    """
    support = est.get_support(indices=False)
    resolved = _get_feature_names(
        est,
        feature_names=in_names,
        num_features=len(support),
    )
    # Positions of the features the selector keeps.
    kept_positions = np.flatnonzero(support)
    return [resolved[pos] for pos in kept_positions]
try:
    from sklearn.linear_model import (
        RandomizedLogisticRegression,
        RandomizedLasso,
    )
    # Reuse the selector handler for both randomized linear models,
    # registering RandomizedLasso first and then
    # RandomizedLogisticRegression.
    for _randomized_cls in (RandomizedLasso, RandomizedLogisticRegression):
        _select_names = transform_feature_names.register(
            _randomized_cls)(_select_names)
    # Do not leave the loop variable behind in the module namespace.
    del _randomized_cls
except ImportError:  # Removed in scikit-learn 0.21
    pass
# Scaling
@transform_feature_names.register(MinMaxScaler)
@transform_feature_names.register(StandardScaler)
@transform_feature_names.register(MaxAbsScaler)
@transform_feature_names.register(RobustScaler)
def _transform_scaling(est, in_names=None):
    """Return feature names for a scaler; scaling leaves names unchanged.

    When ``in_names`` is given, a list copy of it is returned.  Otherwise
    names are generated by ``_get_feature_names`` using the number of
    features inferred from the fitted scaler.

    The feature count is taken from ``est.scale_`` when it is set.  A
    fitted scaler can have ``scale_ = None`` (scikit-learn documents this
    for ``StandardScaler(with_std=False)``); in that case the count falls
    back to ``mean_``, then ``center_``, then ``n_features_in_``.  If none
    of these is available, ``num_features=None`` is passed through and the
    outcome is left to ``_get_feature_names``.
    """
    if in_names is None:
        # Read ``scale_`` directly so an estimator without the attribute
        # still raises AttributeError exactly as before.
        fitted = est.scale_
        for fallback_attr in ('mean_', 'center_'):
            if fitted is not None:
                break
            fitted = getattr(est, fallback_attr, None)
        if fitted is not None:
            num_features = fitted.shape[0]
        else:
            num_features = getattr(est, 'n_features_in_', None)
        in_names = _get_feature_names(est, feature_names=in_names,
                                      num_features=num_features)
    return list(in_names)
# Pipelines
@transform_feature_names.register(Pipeline)
def _pipeline_names(est, in_names=None):
    """Propagate feature names through every step of a pipeline.

    Starting from ``in_names``, each step in ``est.steps`` is applied in
    order via ``transform_feature_names``; the names produced by one step
    are the input names of the next.  Steps set to ``None`` or to the
    string ``'passthrough'`` are identity placeholders and leave the names
    untouched.

    Returns the names after the last step (``in_names`` itself when every
    step is skipped or the pipeline has no steps).
    """
    names = in_names
    for _, trans in est.steps:
        # scikit-learn allows disabling a step with None or 'passthrough';
        # neither is a transformer, so there is nothing to dispatch on.
        if trans is None or trans == 'passthrough':
            continue
        names = transform_feature_names(trans, names)
    return names
@transform_feature_names.register(FeatureUnion)
def _union_names(est, in_names=None):
    """Return the feature names produced by a feature union.

    For every ``(name, transformer, _)`` triple yielded by ``est._iter()``,
    the transformer's output names are computed from ``in_names`` with
    ``transform_feature_names`` and each one is prefixed as
    ``'<name>:<feature>'``.  The results are concatenated in iteration
    order.
    """
    prefixed = []
    for trans_name, trans, _unused in est._iter():
        sub_names = transform_feature_names(trans, in_names)
        prefixed.extend(
            '{}:{}'.format(trans_name, feat_name) for feat_name in sub_names
        )
    return prefixed