New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ExtraTrees Pipeline #790
Changes from all commits
b33bcb1
03b51d4
b0c787d
8e69c9d
5050ffd
e4cc862
1b077bf
77a71e5
30f6cfe
6612f0b
26fc87d
6543fd4
af5f67c
e44989e
5dd78cc
4207a8f
23258d3
07b15be
e66bf80
131f541
695ca3c
c7ef269
c26b2f4
af800f4
bb73610
ad65a49
570d206
3b63c1e
c69cfac
bbd28a2
84f99ca
42f93d6
0280854
34be3a2
817395a
8c03017
ea6d4f1
94759e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from evalml.pipelines import BinaryClassificationPipeline | ||
|
||
|
||
class ETBinaryClassificationPipeline(BinaryClassificationPipeline):
    """Binary classification pipeline built around an Extra Trees estimator.

    Components run in order: one-hot encoding of categoricals, simple
    imputation of missing values, then the Extra Trees Classifier.
    """
    custom_name = "Extra Trees Binary Classification Pipeline"
    component_graph = ["One Hot Encoder", "Simple Imputer", "Extra Trees Classifier"]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from evalml.pipelines import MulticlassClassificationPipeline | ||
|
||
|
||
class ETMulticlassClassificationPipeline(MulticlassClassificationPipeline):
    """Multiclass classification pipeline built around an Extra Trees estimator.

    Components run in order: one-hot encoding of categoricals, simple
    imputation of missing values, then the Extra Trees Classifier.
    """
    custom_name = "Extra Trees Multiclass Classification Pipeline"
    component_graph = ["One Hot Encoder", "Simple Imputer", "Extra Trees Classifier"]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from sklearn.ensemble import ExtraTreesClassifier as SKExtraTreesClassifier | ||
from skopt.space import Integer | ||
|
||
from evalml.model_family import ModelFamily | ||
from evalml.pipelines.components.estimators import Estimator | ||
from evalml.problem_types import ProblemTypes | ||
|
||
|
||
class ExtraTreesClassifier(Estimator):
    """Extra Trees classifier component wrapping scikit-learn's ExtraTreesClassifier.

    Usable for binary and multiclass classification problems.
    """
    name = "Extra Trees Classifier"
    # Search space explored by the automl tuner.
    hyperparameter_ranges = {
        "n_estimators": Integer(10, 1000),
        "max_features": ["auto", "sqrt", "log2"],
        "max_depth": Integer(4, 10),
    }
    model_family = ModelFamily.EXTRA_TREES
    supported_problem_types = [ProblemTypes.BINARY, ProblemTypes.MULTICLASS]

    def __init__(self,
                 n_estimators=100,
                 max_features="auto",
                 max_depth=6,
                 min_samples_split=2,
                 min_weight_fraction_leaf=0.0,
                 n_jobs=-1,
                 random_state=0):
        """Build the component and its underlying sklearn estimator.

        Arguments:
            n_estimators (int): number of trees in the forest.
            max_features (str): feature-subset strategy per split.
            max_depth (int): maximum depth of each tree.
            min_samples_split (int): minimum samples required to split a node.
            min_weight_fraction_leaf (float): minimum weighted fraction at a leaf.
            n_jobs (int): parallelism for sklearn; -1 uses all cores.
            random_state (int): seed for reproducibility.
        """
        # NOTE(review): only the tuned hyperparameters are recorded in
        # `parameters`; min_samples_split, min_weight_fraction_leaf and n_jobs
        # are forwarded to sklearn but not captured — confirm this is intended.
        parameters = {
            "n_estimators": n_estimators,
            "max_features": max_features,
            "max_depth": max_depth,
        }
        sklearn_kwargs = dict(parameters,
                              min_samples_split=min_samples_split,
                              min_weight_fraction_leaf=min_weight_fraction_leaf,
                              n_jobs=n_jobs,
                              random_state=random_state)
        super().__init__(parameters=parameters,
                         component_obj=SKExtraTreesClassifier(**sklearn_kwargs),
                         random_state=random_state)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from sklearn.ensemble import ExtraTreesRegressor as SKExtraTreesRegressor | ||
from skopt.space import Integer | ||
|
||
from evalml.model_family import ModelFamily | ||
from evalml.pipelines.components.estimators import Estimator | ||
from evalml.problem_types import ProblemTypes | ||
|
||
|
||
class ExtraTreesRegressor(Estimator):
    """Extra Trees regressor component wrapping scikit-learn's ExtraTreesRegressor.

    Usable for regression problems.
    """
    name = "Extra Trees Regressor"
    # Search space explored by the automl tuner.
    hyperparameter_ranges = {
        "n_estimators": Integer(10, 1000),
        "max_features": ["auto", "sqrt", "log2"],
        "max_depth": Integer(4, 10),
    }
    model_family = ModelFamily.EXTRA_TREES
    supported_problem_types = [ProblemTypes.REGRESSION]

    def __init__(self,
                 n_estimators=100,
                 max_features="auto",
                 max_depth=6,
                 min_samples_split=2,
                 min_weight_fraction_leaf=0.0,
                 n_jobs=-1,
                 random_state=0):
        """Build the component and its underlying sklearn estimator.

        Arguments:
            n_estimators (int): number of trees in the forest.
            max_features (str): feature-subset strategy per split.
            max_depth (int): maximum depth of each tree.
            min_samples_split (int): minimum samples required to split a node.
            min_weight_fraction_leaf (float): minimum weighted fraction at a leaf.
            n_jobs (int): parallelism for sklearn; -1 uses all cores.
            random_state (int): seed for reproducibility.
        """
        # NOTE(review): only the tuned hyperparameters are recorded in
        # `parameters`; min_samples_split, min_weight_fraction_leaf and n_jobs
        # are forwarded to sklearn but not captured — confirm this is intended.
        parameters = {
            "n_estimators": n_estimators,
            "max_features": max_features,
            "max_depth": max_depth,
        }
        sklearn_kwargs = dict(parameters,
                              min_samples_split=min_samples_split,
                              min_weight_fraction_leaf=min_weight_fraction_leaf,
                              n_jobs=n_jobs,
                              random_state=random_state)
        super().__init__(parameters=parameters,
                         component_obj=SKExtraTreesRegressor(**sklearn_kwargs),
                         random_state=random_state)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from evalml.pipelines import RegressionPipeline | ||
|
||
|
||
class ETRegressionPipeline(RegressionPipeline):
    """Regression pipeline built around an Extra Trees estimator.

    Components run in order: one-hot encoding of categoricals, simple
    imputation of missing values, then the Extra Trees Regressor.
    """
    custom_name = "Extra Trees Regression Pipeline"
    component_graph = ["One Hot Encoder", "Simple Imputer", "Extra Trees Regressor"]
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import numpy as np | ||
import pytest | ||
from sklearn.ensemble import ExtraTreesClassifier as SKExtraTreesClassifier | ||
|
||
from evalml.exceptions import MethodPropertyNotFoundError | ||
from evalml.model_family import ModelFamily | ||
from evalml.pipelines import ExtraTreesClassifier | ||
from evalml.problem_types import ProblemTypes | ||
|
||
|
||
def test_model_family():
    """The classifier component belongs to the EXTRA_TREES model family."""
    assert ModelFamily.EXTRA_TREES == ExtraTreesClassifier.model_family
|
||
|
||
def test_problem_types():
    """The classifier supports binary and multiclass problems, and nothing else."""
    supported = ExtraTreesClassifier.supported_problem_types
    assert ProblemTypes.BINARY in supported
    assert ProblemTypes.MULTICLASS in supported
    assert len(supported) == 2
|
||
|
||
def test_et_parameters():
    """Constructor arguments are reflected in the component's parameters dict."""
    component = ExtraTreesClassifier(n_estimators=20, max_features="auto", max_depth=5, random_state=2)
    assert component.parameters == {
        "n_estimators": 20,
        "max_features": "auto",
        "max_depth": 5
    }
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
|
||
|
||
def test_fit_predict_binary(X_y):
    """Component predictions match the raw sklearn estimator on binary data."""
    X, y = X_y

    # Reference: sklearn estimator configured like the component's defaults.
    reference = SKExtraTreesClassifier(max_depth=6, random_state=0)
    reference.fit(X, y)

    component = ExtraTreesClassifier()
    component.fit(X, y)

    np.testing.assert_almost_equal(component.predict(X), reference.predict(X), decimal=5)
    np.testing.assert_almost_equal(component.predict_proba(X), reference.predict_proba(X), decimal=5)
|
||
|
||
def test_fit_predict_multi(X_y_multi):
    """Component predictions match the raw sklearn estimator on multiclass data."""
    X, y = X_y_multi

    # Reference: sklearn estimator configured like the component's defaults.
    reference = SKExtraTreesClassifier(max_depth=6, random_state=0)
    reference.fit(X, y)

    component = ExtraTreesClassifier()
    component.fit(X, y)

    np.testing.assert_almost_equal(component.predict(X), reference.predict(X), decimal=5)
    np.testing.assert_almost_equal(component.predict_proba(X), reference.predict_proba(X), decimal=5)
|
||
|
||
def test_feature_importances(X_y):
    """feature_importances raises before fit and matches sklearn after fit."""
    X, y = X_y

    component = ExtraTreesClassifier()
    # Accessing feature importances on an unfitted component must raise.
    with pytest.raises(MethodPropertyNotFoundError):
        component.feature_importances

    reference = SKExtraTreesClassifier(max_depth=6, random_state=0)
    reference.fit(X, y)

    component.fit(X, y)
    np.testing.assert_almost_equal(reference.feature_importances_, component.feature_importances, decimal=5)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests are great! One thing which you can add: check that calling There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And same for the regressor of course lol |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep. Makes sense to have a separate family here for now. In the new automl algo this will mean the first round runs both RF and extra trees. Down the road we may want to group tree-based models different but this is a great starting point. @jeremyliweishih