Skip to content

Commit

Permalink
[MRG+1] Add classes_ parameter to hyperparameter CV classes (scikit-l…
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephen Hoover authored and Sundrique committed Jun 14, 2017
1 parent 40a49e0 commit b6a05cd
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 7 deletions.
9 changes: 6 additions & 3 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,12 @@ Enhancements
now uses significantly less memory when assigning data points to their
nearest cluster center. :issue:`7721` by :user:`Jon Crall <Erotemic>`.

- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`
that matches the ``classes_`` attribute of ``best_estimator_``. :issue:`7661`
by :user:`Alyssa Batula <abatula>` and :user:`Dylan Werner-Meier <unautre>`.
- Added ``classes_`` attribute to :class:`model_selection.GridSearchCV`,
:class:`model_selection.RandomizedSearchCV`, :class:`grid_search.GridSearchCV`,
and :class:`grid_search.RandomizedSearchCV` that matches the ``classes_``
attribute of ``best_estimator_``. :issue:`7661` and :issue:`8295`
by :user:`Alyssa Batula <abatula>`, :user:`Dylan Werner-Meier <unautre>`,
and :user:`Stephen Hoover <stephen-hoover>`.

- The ``min_weight_fraction_leaf`` constraint in tree construction is now
more efficient, taking a fast path to declare a node a leaf if its weight
Expand Down
5 changes: 5 additions & 0 deletions sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,11 @@ def inverse_transform(self, Xt):
self._check_is_fitted('inverse_transform')
return self.best_estimator_.transform(Xt)

@property
def classes_(self):
self._check_is_fitted("classes_")
return self.best_estimator_.classes_

def fit(self, X, y=None, groups=None, **fit_params):
"""Run fit with all sets of parameters.
Expand Down
30 changes: 29 additions & 1 deletion sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import Imputer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier
from sklearn.linear_model import Ridge, SGDClassifier

from sklearn.model_selection.tests.common import OneTimeSplitter

Expand All @@ -73,6 +73,7 @@ def __init__(self, foo_param=0):

def fit(self, X, Y):
assert_true(len(X) == len(Y))
self.classes_ = np.unique(Y)
return self

def predict(self, T):
Expand Down Expand Up @@ -323,6 +324,33 @@ def test_grid_search_groups():
gs.fit(X, y)


def test_classes__property():
# Test that classes_ property matches best_estimator_.classes_
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)
Cs = [.1, 1, 10]

grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
grid_search.fit(X, y)
assert_array_equal(grid_search.best_estimator_.classes_,
grid_search.classes_)

# Test that regressors do not have a classes_ attribute
grid_search = GridSearchCV(Ridge(), {'alpha': [1.0, 2.0]})
grid_search.fit(X, y)
assert_false(hasattr(grid_search, 'classes_'))

# Test that the grid searcher has no classes_ attribute before it's fit
grid_search = GridSearchCV(LinearSVC(random_state=0), {'C': Cs})
assert_false(hasattr(grid_search, 'classes_'))

# Test that the grid searcher has no classes_ attribute without a refit
grid_search = GridSearchCV(LinearSVC(random_state=0),
{'C': Cs}, refit=False)
grid_search.fit(X, y)
assert_false(hasattr(grid_search, 'classes_'))


def test_trivial_cv_results_attr():
# Test search over a "grid" with only one point.
# Non-regression test: grid_scores_ wouldn't be set by GridSearchCV.
Expand Down
16 changes: 13 additions & 3 deletions sklearn/model_selection/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from sklearn.datasets import make_multilabel_classification

from sklearn.model_selection.tests.common import OneTimeSplitter
from sklearn.model_selection import GridSearchCV


try:
Expand Down Expand Up @@ -914,7 +915,7 @@ def test_cross_val_predict_sparse_prediction():
assert_array_almost_equal(preds_sparse, preds)


def test_cross_val_predict_with_method():
def check_cross_val_predict_with_method(est):
iris = load_iris()
X, y = iris.data, iris.target
X, y = shuffle(X, y, random_state=0)
Expand All @@ -924,8 +925,6 @@ def test_cross_val_predict_with_method():

methods = ['decision_function', 'predict_proba', 'predict_log_proba']
for method in methods:
est = LogisticRegression()

predictions = cross_val_predict(est, X, y, method=method)
assert_equal(len(predictions), len(y))

Expand Down Expand Up @@ -955,6 +954,17 @@ def test_cross_val_predict_with_method():
assert_array_equal(predictions, predictions_ystr)


def test_cross_val_predict_with_method():
check_cross_val_predict_with_method(LogisticRegression())


def test_gridsearchcv_cross_val_predict_with_method():
est = GridSearchCV(LogisticRegression(random_state=42),
{'C': [0.1, 1]},
cv=2)
check_cross_val_predict_with_method(est)


def get_expected_predictions(X, y, cv, classes, est, method):

expected_predictions = np.zeros([len(y), classes])
Expand Down

0 comments on commit b6a05cd

Please sign in to comment.