Skip to content

Commit

Permalink
Merge branch 'adithyabsk-impute-predict-proba' into development
Browse files Browse the repository at this point in the history
  • Loading branch information
weixuanfu committed Jul 2, 2018
2 parents 643509e + f3e469b commit 28e84b2
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/tpot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,26 @@ def test_predict_2():
assert result.shape == (testing_features.shape[0],)


def test_predict_3():
"""Assert that the TPOT predict function works on dataset with nan"""
tpot_obj = TPOTClassifier()
pipeline_string = (
'DecisionTreeClassifier('
'input_matrix, '
'DecisionTreeClassifier__criterion=gini, '
'DecisionTreeClassifier__max_depth=8, '
'DecisionTreeClassifier__min_samples_leaf=5, '
'DecisionTreeClassifier__min_samples_split=5'
')'
)
tpot_obj._optimized_pipeline = creator.Individual.from_string(pipeline_string, tpot_obj._pset)
tpot_obj.fitted_pipeline_ = tpot_obj._toolbox.compile(expr=tpot_obj._optimized_pipeline)
tpot_obj.fitted_pipeline_.fit(training_features, training_target)
result = tpot_obj.predict(features_with_nan)

assert result.shape == (features_with_nan.shape[0],)


def test_predict_proba():
"""Assert that the TPOT predict_proba function returns a numpy matrix of shape (num_testing_rows, num_testing_target)."""
tpot_obj = TPOTClassifier()
Expand Down Expand Up @@ -671,6 +691,27 @@ def test_predict_proba_4():
assert_raises(RuntimeError, tpot_obj.predict_proba, testing_features)


def test_predict_proba_5():
"""Assert that the TPOT predict_proba function works on dataset with nan."""
tpot_obj = TPOTClassifier()
pipeline_string = (
'DecisionTreeClassifier('
'input_matrix, '
'DecisionTreeClassifier__criterion=gini, '
'DecisionTreeClassifier__max_depth=8, '
'DecisionTreeClassifier__min_samples_leaf=5, '
'DecisionTreeClassifier__min_samples_split=5)'
)
tpot_obj._optimized_pipeline = creator.Individual.from_string(pipeline_string, tpot_obj._pset)
tpot_obj.fitted_pipeline_ = tpot_obj._toolbox.compile(expr=tpot_obj._optimized_pipeline)
tpot_obj.fitted_pipeline_.fit(training_features, training_target)

result = tpot_obj.predict_proba(features_with_nan)
num_labels = np.amax(training_target) + 1

assert result.shape == (features_with_nan.shape[0], num_labels)


def test_warm_start():
"""Assert that the TPOT warm_start flag stores the pop and pareto_front from the first run."""
tpot_obj = TPOTClassifier(
Expand Down
6 changes: 6 additions & 0 deletions tpot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,12 @@ def predict_proba(self, features):
else:
if not (hasattr(self.fitted_pipeline_, 'predict_proba')):
raise RuntimeError('The fitted pipeline does not have the predict_proba() function.')

features = features.astype(np.float64)

if np.any(np.isnan(features)):
features = self._impute_values(features)

return self.fitted_pipeline_.predict_proba(features.astype(np.float64))

def set_params(self, **params):
Expand Down

0 comments on commit 28e84b2

Please sign in to comment.