Skip to content

Commit

Permalink
Merge pull request #916 from jhmenke/development
Browse files Browse the repository at this point in the history
Enable custom scorer object for parallel computation #914
  • Loading branch information
weixuanfu committed Sep 16, 2019
2 parents d009554 + eafbe13 commit fc25bad
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 14 deletions.
10 changes: 5 additions & 5 deletions tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_init_default_scoring_2():
assert len(w) == 1 # deap 1.2.2 warning message made this unit test failed
assert issubclass(w[-1].category, DeprecationWarning) # deap 1.2.2 warning message made this unit test failed
assert "This scoring type was deprecated" in str(w[-1].message) # deap 1.2.2 warning message made this unit test failed
assert tpot_obj.scoring_function == 'balanced_accuracy'
assert tpot_obj.scoring_function._score_func == balanced_accuracy


def test_init_default_scoring_3():
Expand All @@ -191,7 +191,7 @@ def test_init_default_scoring_3():
tpot_obj = TPOTClassifier(scoring=make_scorer(balanced_accuracy))
tpot_obj._fit_init()
assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed
assert tpot_obj.scoring_function == 'balanced_accuracy'
assert tpot_obj.scoring_function._score_func == balanced_accuracy


def test_init_default_scoring_4():
Expand All @@ -203,7 +203,7 @@ def my_scorer(clf, X, y):
tpot_obj = TPOTClassifier(scoring=my_scorer)
tpot_obj._fit_init()
assert len(w) == 0 # deap 1.2.2 warning message made this unit test failed
assert tpot_obj.scoring_function == 'my_scorer'
assert tpot_obj.scoring_function == my_scorer


def test_init_default_scoring_5():
Expand All @@ -214,7 +214,7 @@ def test_init_default_scoring_5():
assert len(w) == 1
assert issubclass(w[-1].category, DeprecationWarning)
assert "This scoring type was deprecated" in str(w[-1].message)
assert tpot_obj.scoring_function == 'roc_auc_score'
assert tpot_obj.scoring_function._score_func == roc_auc_score


def test_init_default_scoring_6():
Expand All @@ -228,7 +228,7 @@ def my_scorer(y_true, y_pred):
assert issubclass(w[-1].category, DeprecationWarning)
assert "This scoring type was deprecated" in str(w[-1].message)

assert tpot_obj.scoring_function == 'my_scorer'
assert tpot_obj.scoring_function._score_func == my_scorer


def test_invalid_score_warning():
Expand Down
19 changes: 10 additions & 9 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def _setup_scoring_function(self, scoring):
'choose a valid scoring function from the TPOT '
'documentation.'.format(scoring)
)
self.scoring_function = scoring
elif callable(scoring):
# Heuristic to ensure user has not passed a metric
module = getattr(scoring, '__module__', None)
Expand All @@ -342,21 +343,15 @@ def _setup_scoring_function(self, scoring):
not module.startswith('sklearn.metrics.tests.')):
scoring_name = scoring.__name__
greater_is_better = 'loss' not in scoring_name and 'error' not in scoring_name
SCORERS[scoring_name] = make_scorer(scoring, greater_is_better=greater_is_better)
self.scoring_function = make_scorer(scoring, greater_is_better=greater_is_better)
warnings.simplefilter('always', DeprecationWarning)
warnings.warn('Scoring function {} looks like it is a metric function '
'rather than a scikit-learn scorer. This scoring type was deprecated '
'in version TPOT 0.9.1 and will be removed in version 0.11. '
'Please update your custom scoring function.'.format(scoring), DeprecationWarning)
else:
if isinstance(scoring, _BaseScorer):
scoring_name = scoring._score_func.__name__
else:
scoring_name = scoring.__name__
SCORERS[scoring_name] = scoring
scoring = scoring_name
self.scoring_function = scoring

self.scoring_function = scoring

def _setup_config(self, config_dict):
if config_dict:
Expand Down Expand Up @@ -969,7 +964,13 @@ def score(self, testing_features, testing_target):

# If the scoring function is a string, we must adjust to use the sklearn
# scoring interface
score = SCORERS[self.scoring_function](
if isinstance(self.scoring_function, str):
scorer = SCORERS[self.scoring_function]
elif callable(self.scoring_function):
scorer = self.scoring_function
else:
raise RuntimeError('The scoring function should either be the name of a scikit-learn scorer or a scorer object')
score = scorer(
self.fitted_pipeline_,
testing_features.astype(np.float64),
testing_target.astype(np.float64)
Expand Down

0 comments on commit fc25bad

Please sign in to comment.