GridSearchCV now can refit the best algorithm on full dataset
NicolasHug committed Jan 7, 2018
1 parent 112c8f9 commit 2b67f93
Showing 4 changed files with 109 additions and 7 deletions.
3 changes: 1 addition & 2 deletions TODO.md
@@ -3,8 +3,6 @@ TODO


* Update README example before new release, as well as computation times
* grid search should allow the refit param and the test method using the best
estimator

* check conda forge
* make some filtering dataset tools, like remove users/items with less/more
@@ -14,6 +12,7 @@ TODO
Done:
-----

* Grid search now has the refit param.
* Grid search and cross_validate now allow return_train_score
* Make all fit methods return self. Update docs on building custom algorithms
* Update doc of MF algo to indicate how to retrieve latent factors.
10 changes: 6 additions & 4 deletions doc/source/getting_started.rst
@@ -195,10 +195,10 @@ Use cross-validation iterators
------------------------------

For cross-validation, we can use the :func:`cross_validate()
<surprise.model_selection.validation.cross_validate>` function that does all
the hard work for us. But for better control, we can also instantiate a
cross-validation iterator, and make predictions over each split using the
``split()`` method of the iterator, and the
:meth:`test()<surprise.prediction_algorithms.algo_base.AlgoBase.test>` method
of the algorithm. Here is an example where we use a classical K-fold
cross-validation procedure with 3 splits:
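A minimal sketch of such a loop, assuming the built-in ml-100k dataset and
the SVD algorithm:

    from surprise import SVD, Dataset, accuracy
    from surprise.model_selection import KFold

    # Load the movielens-100k dataset (downloaded on first use).
    data = Dataset.load_builtin('ml-100k')

    # Define a K-fold cross-validation iterator with 3 splits.
    kf = KFold(n_splits=3)
    algo = SVD()

    for trainset, testset in kf.split(data):
        # Train on this split's trainset, then evaluate on its testset.
        algo.fit(trainset)
        predictions = algo.test(testset)
        accuracy.rmse(predictions, verbose=True)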
@@ -217,6 +217,8 @@ Result could be, e.g.:
Other cross-validation iterators can be used, like LeaveOneOut or ShuffleSplit.
See all the available iterators :ref:`here <cross_validation_iterators_api>`.
The design of Surprise's cross-validation tools is heavily inspired by the
excellent scikit-learn API.
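For instance, a ShuffleSplit variant of the loop above (a sketch reusing
``data`` and ``algo`` from the previous snippet) might be:

    from surprise.model_selection import ShuffleSplit

    # 3 random train/test splits, holding out 20% of the ratings each time.
    ss = ShuffleSplit(n_splits=3, test_size=.2)
    for trainset, testset in ss.split(data):
        algo.fit(trainset)
        accuracy.rmse(algo.test(testset), verbose=True)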

---------------------

53 changes: 52 additions & 1 deletion surprise/model_selection/search.py
@@ -5,9 +5,11 @@
import numpy as np
from joblib import Parallel
from joblib import delayed
from six import string_types

from .split import get_cv
from .validation import fit_and_score
from ..dataset import DatasetUserFolds


class GridSearchCV:
@@ -38,6 +40,14 @@ class GridSearchCV:
appropriate ``n_splits`` parameter. If ``None``, :class:`KFold
<surprise.model_selection.split.KFold>` is used with
``n_splits=5``.
refit(bool or str): If ``True``, refit the algorithm on the whole
dataset using the set of parameters that gave the best average
performance for the first measure of ``measures``. A different
measure can be selected by passing its name as a string. Then, you can
use the ``test()`` and ``predict()`` methods.
``refit`` can only be used if the ``data`` parameter given to
``fit()`` hasn't been loaded with :meth:`load_from_folds()
<surprise.dataset.Dataset.load_from_folds>`. Default is ``False``.
return_train_measures(bool): Whether to compute performance measures on
the trainsets. If ``True``, the ``cv_results`` attribute will
also contain measures for trainsets. Default is ``False``.
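A short usage sketch of the new parameter (the parameter grid here is
hypothetical; the built-in ml-100k dataset is assumed):

    from surprise import SVD, Dataset
    from surprise.model_selection import GridSearchCV

    data = Dataset.load_builtin('ml-100k')
    param_grid = {'n_epochs': [5, 10], 'lr_all': [0.002, 0.005]}

    # With refit=True, the best algorithm w.r.t. the first measure
    # ('rmse' here) is re-trained on the whole dataset after the search.
    gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], cv=3,
                      refit=True)
    gs.fit(data)

    # predict() now delegates to the refitted estimator. Raw user and
    # item ids are strings in the ml-100k dataset.
    pred = gs.predict('196', '302')
    print(pred.est)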
@@ -92,13 +102,23 @@ class GridSearchCV:
'''

def __init__(self, algo_class, param_grid, measures=['rmse', 'mae'],
cv=None, refit=False, return_train_measures=False, n_jobs=-1,
pre_dispatch='2*n_jobs', joblib_verbose=0):

self.algo_class = algo_class
self.param_grid = param_grid.copy()
self.measures = [measure.lower() for measure in measures]
self.cv = cv
if isinstance(refit, string_types):
if refit.lower() not in self.measures:
raise ValueError('It looks like the measure you want to use '
'with refit ({}) is not in the measures '
'parameter'.format(refit))
self.refit = refit.lower()
elif refit is True:
self.refit = self.measures[0]
else:
self.refit = False
self.return_train_measures = return_train_measures
self.n_jobs = n_jobs
self.pre_dispatch = pre_dispatch
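To illustrate how the constructor above normalizes ``refit`` (a sketch; the
one-entry parameter grid is hypothetical):

    from surprise import SVD
    from surprise.model_selection import GridSearchCV

    param_grid = {'n_epochs': [5, 10]}

    # A measure that is not in `measures` is rejected at construction time.
    try:
        GridSearchCV(SVD, param_grid, measures=['rmse'], refit='mae')
    except ValueError as e:
        print(e)

    # Measure names are lowercased, and refit=True falls back to measures[0].
    gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], refit='MAE')
    assert gs.refit == 'mae'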
@@ -130,6 +150,10 @@ def fit(self, data):
which to evaluate the algorithm, in parallel.
'''

if self.refit and isinstance(data, DatasetUserFolds):
raise ValueError('refit cannot be used when data has been '
'loaded with load_from_folds().')

cv = get_cv(self.cv)

delayed_list = (
@@ -221,8 +245,35 @@ def fit(self, data):
# cv_results: set params key
cv_results['params'] = self.param_combinations

if self.refit:
best_estimator[self.refit].fit(data.build_full_trainset())

self.best_index = best_index
self.best_params = best_params
self.best_score = best_score
self.best_estimator = best_estimator
self.cv_results = cv_results

def test(self, testset, verbose=False):
'''Call ``test()`` on the estimator with the best found parameters
(according to the ``refit`` parameter). See :meth:`AlgoBase.test()
<surprise.prediction_algorithms.algo_base.AlgoBase.test>`.
Only available if ``refit`` is not ``False``.
'''

if not self.refit:
raise ValueError('refit is False, cannot use test()')
return self.best_estimator[self.refit].test(testset, verbose)

def predict(self, *args):
'''Call ``predict()`` on the estimator with the best found parameters
(according to the ``refit`` parameter). See :meth:`AlgoBase.predict()
<surprise.prediction_algorithms.algo_base.AlgoBase.predict>`.
Only available if ``refit`` is not ``False``.
'''

if not self.refit:
raise ValueError('refit is False, cannot use predict()')
return self.best_estimator[self.refit].predict(*args)
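A sketch of ``test()`` after refitting on a named measure (evaluating on the
training ratings purely to illustrate the API; ``construct_testset()`` is the
same helper used in the tests below):

    from surprise import SVD, Dataset, accuracy
    from surprise.model_selection import GridSearchCV

    data = Dataset.load_builtin('ml-100k')
    param_grid = {'n_epochs': [5], 'lr_all': [0.002, 0.005]}

    gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], cv=3,
                      refit='rmse')
    gs.fit(data)

    # test() delegates to the estimator refitted for the 'rmse' measure.
    testset = data.construct_testset(data.raw_ratings)
    predictions = gs.test(testset)
    accuracy.rmse(predictions)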
50 changes: 50 additions & 0 deletions tests/test_search.py
@@ -7,6 +7,7 @@
import os

import numpy as np
import pytest

from surprise import Dataset
from surprise import Reader
@@ -139,3 +140,52 @@ def test_cv_results():
assert gs.cv_results['params'][best_index] == gs.best_params['rmse']
best_index = np.argmin(gs.cv_results['rank_test_mae'])
assert gs.cv_results['params'][best_index] == gs.best_params['mae']


def test_refit():

data_file = os.path.join(os.path.dirname(__file__), './u1_ml100k_test')
data = Dataset.load_from_file(data_file, Reader('ml-100k'))

param_grid = {'n_epochs': [5], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6], 'n_factors': [2]}

# assert gs.fit() and gs.test() will use the best estimator for mae (first
# appearing in measures)
gs = GridSearchCV(SVD, param_grid, measures=['mae', 'rmse'], cv=2,
refit=True)
gs.fit(data)
gs_preds = gs.test(data.construct_testset(data.raw_ratings))
mae_preds = gs.best_estimator['mae'].test(
data.construct_testset(data.raw_ratings))
assert gs_preds == mae_preds

# assert gs.fit() and gs.test() will use the best estimator for rmse
gs = GridSearchCV(SVD, param_grid, measures=['mae', 'rmse'], cv=2,
refit='rmse')
gs.fit(data)
gs_preds = gs.test(data.construct_testset(data.raw_ratings))
rmse_preds = gs.best_estimator['rmse'].test(
data.construct_testset(data.raw_ratings))
assert gs_preds == rmse_preds
# test that predict() can be called
gs.predict(2, 4)

# assert test() and predict() cannot be used when refit is False
gs = GridSearchCV(SVD, param_grid, measures=['mae', 'rmse'], cv=2,
refit=False)
gs.fit(data)
with pytest.raises(ValueError):
gs_preds = gs.test(data.construct_testset(data.raw_ratings))
with pytest.raises(ValueError):
gs.predict('1', '2')

# test that error is raised if used with load_from_folds
train_file = os.path.join(os.path.dirname(__file__), './u1_ml100k_train')
test_file = os.path.join(os.path.dirname(__file__), './u1_ml100k_test')
data = Dataset.load_from_folds([(train_file, test_file)],
Reader('ml-100k'))
gs = GridSearchCV(SVD, param_grid, measures=['mae', 'rmse'], cv=2,
refit=True)
with pytest.raises(ValueError):
gs.fit(data)
