Skip to content

Commit

Permalink
[MRG+1] fixed OOB_Score bug for bagging classifiers. (scikit-learn#8936)
Browse files Browse the repository at this point in the history
* fixed OOB_Score bug for bagging slassifiers.
See: scikit-learn#8933

* Added white space

* more white space fixing

* Adding test for oob_score validity

* removing pandas, replacing with numpy matrices

* fixing white space

* more white space fixing

* white space ...

* fixed labels to allow for strings

* white space

* simplifying test

* white space

* reformatting test

* white space

* pressed enter at end of file

* removing line at end of file
  • Loading branch information
mlewis1729 authored and AishwaryaRK committed Aug 29, 2017
1 parent c35dc98 commit 3816f82
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
3 changes: 1 addition & 2 deletions sklearn/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,8 +608,7 @@ def _set_oob_score(self, X, y):

oob_decision_function = (predictions /
predictions.sum(axis=1)[:, np.newaxis])
oob_score = accuracy_score(y, classes_.take(np.argmax(predictions,
axis=1)))
oob_score = accuracy_score(y, np.argmax(predictions, axis=1))

self.oob_decision_function_ = oob_decision_function
self.oob_score_ = oob_score
Expand Down
17 changes: 17 additions & 0 deletions sklearn/ensemble/tests/test_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,20 @@ def test_max_samples_consistency():
max_features=0.5, random_state=1)
bagging.fit(X, y)
assert_equal(bagging._max_samples, max_samples)


def test_set_oob_score_label_encoding():
# Make sure the oob_score doesn't change when the labels change
# See: https://github.com/scikit-learn/scikit-learn/issues/8933
randState = 5
X = [[-1], [0], [1]] * 5
Y1 = ['A', 'B', 'C'] * 5
Y2 = [-1, 0, 1] * 5
Y3 = [0, 1, 2] * 5
x1 = BaggingClassifier(oob_score=True,
random_state=randState).fit(X, Y1).oob_score_
x2 = BaggingClassifier(oob_score=True,
random_state=randState).fit(X, Y2).oob_score_
x3 = BaggingClassifier(oob_score=True,
random_state=randState).fit(X, Y3).oob_score_
assert_equal([x1, x2], [x3, x3])

0 comments on commit 3816f82

Please sign in to comment.