Skip to content

Commit

Permalink
[MRG + 1] Issue#8062: JoblibException thrown when passing "fit_params…
Browse files Browse the repository at this point in the history
…={'sample_… (scikit-learn#8068)

* Issue#8062: JoblibException thrown when passing "fit_params={'sample_weights': weights}" to RandomizedSearchCV with RandomForestClassifier

* Added test for issues scikit-learn#8068 and scikit-learn#8064.

* Clean up with pyflakes.

* Changed cryptic comment.
  • Loading branch information
xor authored and NelleV committed Aug 11, 2017
1 parent 2b90fcb commit 1a18cb5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def fit(self, X, y, sample_weight=None):
# Validate or convert input data
X = check_array(X, accept_sparse="csc", dtype=DTYPE)
y = check_array(y, accept_sparse='csc', ensure_2d=False, dtype=None)
if sample_weight is not None:
sample_weight = check_array(sample_weight, ensure_2d=False)
if issparse(X):
# Pre-sort indices to avoid that each individual tree of the
# ensemble sorts the indices.
Expand Down
6 changes: 6 additions & 0 deletions sklearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,6 +947,12 @@ def check_class_weights(name):
clf2.fit(iris.data, iris.target, sample_weight)
assert_almost_equal(clf1.feature_importances_, clf2.feature_importances_)

# Using a Python 2.x list as the sample_weight parameter used to raise
# an exception. This test makes sure such code will now run correctly.
clf = ForestClassifier()
sample_weight = [1.] * len(iris.data)
clf.fit(iris.data, iris.target, sample_weight=sample_weight)


def test_class_weights():
for name in FOREST_CLASSIFIERS:
Expand Down

0 comments on commit 1a18cb5

Please sign in to comment.