Skip to content

Commit

Permalink
Merge a45046d into e7e4af5
Browse files Browse the repository at this point in the history
  • Loading branch information
beckernick committed Sep 3, 2020
2 parents e7e4af5 + a45046d commit 92b4109
Show file tree
Hide file tree
Showing 10 changed files with 1,056 additions and 16 deletions.
10 changes: 10 additions & 0 deletions docs_sources/using.md
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,16 @@ Currently only classification is supported, but future releases will include reg
<td align="center"><a href="https://github.com/EpistasisLab/tpot/blob/master/tpot/config/classifier_nn.py">Classification</a></td>
</tr>

<tr>
<td>TPOT-cuML</td>
<td>TPOT will search over a restricted configuration using the GPU-accelerated estimators in <a href="https://github.com/rapidsai/cuml">RAPIDS cuML</a> and <a href="https://github.com/dmlc/xgboost">DMLC XGBoost</a>. This configuration requires an NVIDIA Pascal architecture or better GPU with compute capability 6.0+, and that the library cuML is installed. With this configuration, all model training and predicting will be GPU-accelerated.
<br /><br />
This configuration is particularly useful for medium-sized and larger datasets on which CPU-based estimators are a common bottleneck, and works for both the TPOTClassifier and TPOTRegressor.</td>
<td align="center"><a href="https://github.com/EpistasisLab/tpot/blob/master/tpot/config/classifier_cuml.py">Classification</a>
<br /><br />
<a href="https://github.com/EpistasisLab/tpot/blob/master/tpot/config/regressor_cuml.py">Regression</a></td>
</tr>

</table>

To use any of these configurations, simply pass the string name of the configuration to the `config_dict` parameter (or `-config` on the command line). For example, to use the "TPOT light" configuration:
Expand Down
47 changes: 46 additions & 1 deletion tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""

from tpot import TPOTClassifier, TPOTRegressor
from tpot.base import TPOTBase
from tpot.base import TPOTBase, _has_cuml
from tpot.driver import float_range
from tpot.gp_types import Output_Array
from tpot.gp_deap import mutNodeReplacement, _wrapped_cross_val_score, pick_two_individuals_eligible_for_crossover, cxOnePoint, varOr, initialize_stats_dict
Expand All @@ -39,6 +39,8 @@
from tpot.config.regressor_sparse import regressor_config_sparse
from tpot.config.classifier_sparse import classifier_config_sparse
from tpot.config.classifier_nn import classifier_config_nn
from tpot.config.classifier_cuml import classifier_config_cuml
from tpot.config.regressor_cuml import regressor_config_cuml

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -515,6 +517,15 @@ def test_conf_dict():
tpot_obj._fit_init()
assert tpot_obj._config_dict == regressor_config_sparse

if _has_cuml():
tpot_obj = TPOTClassifier(config_dict='TPOT cuML')
tpot_obj._fit_init()
assert tpot_obj._config_dict == classifier_config_cuml

tpot_obj = TPOTRegressor(config_dict='TPOT cuML')
tpot_obj._fit_init()
assert tpot_obj._config_dict == regressor_config_cuml


def test_conf_dict_2():
"""Assert that TPOT uses a custom dictionary of operators when config_dict is Python dictionary."""
Expand Down Expand Up @@ -1109,6 +1120,40 @@ def test_fit_7():
assert not (tpot_obj._start_datetime is None)


def test_fit_cuml():
"""Assert that the TPOT fit function provides an optimized pipeline when config_dict is 'TPOT cuML' if cuML is available. If not available, assert _fit_init raises a ValueError."""

tpot_clf_obj = TPOTClassifier(
random_state=42,
population_size=1,
offspring_size=2,
generations=1,
verbosity=0,
config_dict='TPOT cuML'
)

tpot_regr_obj = TPOTRegressor(
random_state=42,
population_size=1,
offspring_size=2,
generations=1,
verbosity=0,
config_dict='TPOT cuML'
)

if _has_cuml():
tpot_clf_obj.fit(training_features, training_target)
assert isinstance(tpot_clf_obj._optimized_pipeline, creator.Individual)
assert not (tpot_clf_obj._start_datetime is None)

tpot_regr_obj.fit(pretest_X_reg, pretest_y_reg)
assert isinstance(tpot_regr_obj._optimized_pipeline, creator.Individual)
assert not (tpot_regr_obj._start_datetime is None)
else:
assert_raises(ValueError, tpot_clf_obj._fit_init)
assert_raises(ValueError, tpot_regr_obj._fit_init)


def test_memory():
"""Assert that the TPOT fit function runs normally with memory=\'auto\'."""
tpot_obj = TPOTClassifier(
Expand Down
20 changes: 20 additions & 0 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@
from .config.regressor_sparse import regressor_config_sparse
from .config.classifier_sparse import classifier_config_sparse
from .config.classifier_nn import classifier_config_nn
from .config.classifier_cuml import classifier_config_cuml
from .config.regressor_cuml import regressor_config_cuml

from .metrics import SCORERS
from .gp_types import Output_Array
Expand Down Expand Up @@ -345,6 +347,16 @@ def _setup_config(self, config_dict):
self._config_dict = regressor_config_sparse
elif config_dict == 'TPOT NN':
self._config_dict = classifier_config_nn
elif config_dict == 'TPOT cuML':
if not _has_cuml():
raise ValueError(
'The GPU machine learning library cuML is not '
'available. To use cuML, please install cuML via conda.'
)
elif self.classification:
self._config_dict = classifier_config_cuml
else:
self._config_dict = regressor_config_cuml
else:
config = self._read_config_file(config_dict)
if hasattr(config, 'tpot_config'):
Expand Down Expand Up @@ -1722,3 +1734,11 @@ def _generate(self, pset, min_, max_, condition, type_=None):
for arg in reversed(prim.args):
stack.append((depth + 1, arg))
return expr


def _has_cuml():
try:
import cuml
return True
except ImportError:
return False
5 changes: 2 additions & 3 deletions tpot/builtins/stacking_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, ClassifierMixin
from sklearn.base import BaseEstimator, TransformerMixin, is_classifier
from sklearn.utils import check_array


Expand Down Expand Up @@ -83,13 +83,12 @@ def transform(self, X):
X = check_array(X)
X_transformed = np.copy(X)
# add class probabilities as a synthetic feature
if issubclass(self.estimator.__class__, ClassifierMixin) and hasattr(self.estimator, 'predict_proba'):
if is_classifier(self.estimator) and hasattr(self.estimator, 'predict_proba'):
y_pred_proba = self.estimator.predict_proba(X)
# check all values that should be not infinity or not NAN
if np.all(np.isfinite(y_pred_proba)):
X_transformed = np.hstack((y_pred_proba, X))

# add class prediction as a synthetic feature
X_transformed = np.hstack((np.reshape(self.estimator.predict(X), (-1, 1)), X_transformed))

return X_transformed
130 changes: 130 additions & 0 deletions tpot/config/classifier_cuml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-

"""This file is part of the TPOT library.
TPOT was primarily developed at the University of Pennsylvania by:
- Randal S. Olson (rso@randalolson.com)
- Weixuan Fu (weixuanf@upenn.edu)
- Daniel Angell (dpa34@drexel.edu)
- and many more generous open source contributors
TPOT is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as
published by the Free Software Foundation, either version 3 of
the License, or (at your option) any later version.
TPOT is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public
License along with TPOT. If not, see <http://www.gnu.org/licenses/>.
"""

import numpy as np

# This configuration provides users with access to a GPU the ability to
# use RAPIDS cuML and DMLC/XGBoost classifiers as estimators alongside
# the scikit-learn preprocessors in the TPOT default configuration.

classifier_config_cuml = {
# cuML + DMLC/XGBoost Classifiers

"cuml.neighbors.KNeighborsClassifier": {
"n_neighbors": range(1, 101),
"weights": ["uniform",],
},

"cuml.linear_model.LogisticRegression": {
"penalty": ["l1", "l2", "elasticnet"],
"C": [1e-4, 1e-3, 1e-2, 1e-1, 0.5, 1., 5., 10., 15., 20., 25.,],
},

"xgboost.XGBClassifier": {
"n_estimators": [100],
"max_depth": range(3, 10),
"learning_rate": [1e-2, 1e-1, 0.5, 1.],
"subsample": np.arange(0.05, 1.01, 0.05),
"min_child_weight": range(1, 21),
"alpha": [1, 10],
"tree_method": ["gpu_hist"],
"nthread": [1]
},

# Sklearn Preprocesssors

"sklearn.preprocessing.Binarizer": {
"threshold": np.arange(0.0, 1.01, 0.05)
},

"sklearn.decomposition.FastICA": {
"tol": np.arange(0.0, 1.01, 0.05)
},

"sklearn.cluster.FeatureAgglomeration": {
"linkage": ["ward", "complete", "average"],
"affinity": ["euclidean", "l1", "l2", "manhattan", "cosine"]
},

"sklearn.preprocessing.MaxAbsScaler": {
},

"sklearn.preprocessing.MinMaxScaler": {
},

"sklearn.preprocessing.Normalizer": {
"norm": ["l1", "l2", "max"]
},

"sklearn.kernel_approximation.Nystroem": {
"kernel": ["rbf", "cosine", "chi2", "laplacian", "polynomial", "poly", "linear", "additive_chi2", "sigmoid"],
"gamma": np.arange(0.0, 1.01, 0.05),
"n_components": range(1, 11)
},

"sklearn.decomposition.PCA": {
"svd_solver": ["randomized"],
"iterated_power": range(1, 11)
},

"sklearn.kernel_approximation.RBFSampler": {
"gamma": np.arange(0.0, 1.01, 0.05)
},

"sklearn.preprocessing.RobustScaler": {
},

"sklearn.preprocessing.StandardScaler": {
},

"tpot.builtins.ZeroCount": {
},

"tpot.builtins.OneHotEncoder": {
"minimum_fraction": [0.05, 0.1, 0.15, 0.2, 0.25],
"sparse": [False],
"threshold": [10]
},

# Selectors

"sklearn.feature_selection.SelectFwe": {
"alpha": np.arange(0, 0.05, 0.001),
"score_func": {
"sklearn.feature_selection.f_classif": None
}
},

"sklearn.feature_selection.SelectPercentile": {
"percentile": range(1, 100),
"score_func": {
"sklearn.feature_selection.f_classif": None
}
},

"sklearn.feature_selection.VarianceThreshold": {
"threshold": [0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.2]
}
}

0 comments on commit 92b4109

Please sign in to comment.