Skip to content

Commit

Permalink
ref: remove MERGE regressor in favor of Extra Trees with larger n_est…
Browse files Browse the repository at this point in the history
…imators
  • Loading branch information
paulmueller committed Nov 5, 2018
1 parent f67a950 commit b594135
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 34 deletions.
32 changes: 3 additions & 29 deletions nanite/rate/regressors.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,12 @@
"""scikit-learn regressors and their keyword arguments"""
import numpy as np
from sklearn import ensemble, svm, tree


class AverageTreeRegressor(object):
def __init__(self, ensemble_names, random_states):
for n, rs in zip(ensemble_names, random_states):
kwargs = reg_dict[n][1].copy()
kwargs["random_state"] = rs
reg_dict[n][0](**kwargs)

self.ensemble_regs = [reg_dict[n][0](**reg_dict[n][1])
for n in ensemble_names]

def fit(self, *args, **kwargs):
[er.fit(*args, **kwargs) for er in self.ensemble_regs]

def predict(self, *args, **kwargs):
vals = [er.predict(*args, **kwargs) for er in self.ensemble_regs]
return np.mean(np.array(vals), axis=0)


reg_dict = {
"AdaBoost": [
ensemble.AdaBoostRegressor,
{"learning_rate": .5,
"n_estimators": 30,
"n_estimators": 100,
"random_state": 42,
},
],
Expand All @@ -42,7 +23,7 @@ def predict(self, *args, **kwargs):
{"max_depth": 15,
"min_samples_leaf": 2,
"min_samples_split": 5,
"n_estimators": 10,
"n_estimators": 100,
"random_state": 42,
}
],
Expand All @@ -55,18 +36,12 @@ def predict(self, *args, **kwargs):
"random_state": 42,
}
],
"MERGE": [
AverageTreeRegressor,
{"ensemble_names": ["Extra Trees", "Random Forest",
"Gradient Tree Boosting"],
"random_states": [42, 42, 42]}
],
"Random Forest": [
ensemble.RandomForestRegressor,
{"max_depth": 15,
"min_samples_leaf": 2,
"min_samples_split": 7,
"n_estimators": 10,
"n_estimators": 100,
"random_state": 42,
}
],
Expand Down Expand Up @@ -94,7 +69,6 @@ def predict(self, *args, **kwargs):
#: List of tree-based regressor class names (used for keyword defaults in
#: :class:`IndentationRater`)
reg_trees = ["AdaBoostRegressor",
"AverageTreeRegressor",
"DecisionTreeRegressor",
"ExtraTreesRegressor",
"GradientBoostingRegressor",
Expand Down
9 changes: 4 additions & 5 deletions tests/test_qmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,13 @@ def test_feat_rating():
y_axis="force",
segment="approach",
weight_cp=2e-6)
idnt.rate_quality(method="MERGE", ts_label="zef18")
idnt.rate_quality(method="Extra Trees", ts_label="zef18")

qd = qm.get_qmap("meta rating", qmap_only=True)
vals = qd.flat[~np.isnan(qd.flat)]

assert np.allclose(vals[0], 9.765420865311263), "gray matter"
assert np.allclose(vals[2], 4.981720718347044), "white matter"
assert np.allclose(vals[1], 1.7713407492968665), "background"
assert np.allclose(vals[0], 9.471932624275558), "gray matter"
assert np.allclose(vals[2], 4.75182041147194), "white matter"
assert np.allclose(vals[1], 2.568823857492953), "background"


def test_feat_rating_nofit():
Expand Down

0 comments on commit b594135

Please sign in to comment.