Skip to content

Commit

Permalink
Seeding RNG before calling evaluate() in GridSearch. Fixes #95.
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Oct 24, 2017
1 parent 1a12f8a commit 712b670
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 5 deletions.
24 changes: 20 additions & 4 deletions surprise/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import time
import os
from itertools import product
import random

import numpy as np
from six import iteritems
Expand Down Expand Up @@ -144,6 +145,9 @@ class GridSearch:
as in ``'2*n_jobs'``.
Default is ``'2*n_jobs'``.
seed(int): The value to use as seed for RNG. It will determine how
splits are defined. If ``None``, the current time since epoch is
used. Default is ``None``.
verbose(bool): Level of verbosity. If ``False``, nothing is printed. If
``True``, The mean values of each measure are printed along for
each parameter combination. Default is ``True``.
Expand All @@ -170,7 +174,7 @@ class GridSearch:
"""

def __init__(self, algo_class, param_grid, measures=['rmse', 'mae'],
n_jobs=-1, pre_dispatch='2*n_jobs', verbose=1,
n_jobs=-1, pre_dispatch='2*n_jobs', seed=None, verbose=1,
joblib_verbose=0):
self.best_params = CaseInsensitiveDefaultDict(list)
self.best_index = CaseInsensitiveDefaultDict(list)
Expand All @@ -182,6 +186,7 @@ def __init__(self, algo_class, param_grid, measures=['rmse', 'mae'],
self.measures = [measure.upper() for measure in measures]
self.n_jobs = n_jobs
self.pre_dispatch = pre_dispatch
self.seed = seed if seed is not None else int(time.time())
self.verbose = verbose
self.joblib_verbose = joblib_verbose

Expand All @@ -202,6 +207,14 @@ def __init__(self, algo_class, param_grid, measures=['rmse', 'mae'],
self.param_combinations = [dict(zip(self.param_grid, v)) for v in
product(*self.param_grid.values())]

def eval_helper(self, *args):
    """Seed the RNG, then delegate to evaluate().

    Seeding here is required because this helper runs inside joblib
    worker processes: each process would otherwise start with its own
    RNG state and produce different train/test splits, making the
    grid-search scores incomparable across parameter combinations.
    """
    random.seed(self.seed)
    return evaluate(*args, verbose=0)

def evaluate(self, data):
"""Runs the grid search on dataset.
Expand All @@ -218,9 +231,12 @@ def evaluate(self, data):
for combination in self.param_combinations:
print(combination)

delayed_list = (delayed(evaluate)(self.algo_class(**combination), data,
self.measures, verbose=0)
for combination in self.param_combinations)
delayed_list = (
delayed(self.eval_helper)(self.algo_class(**combination),
data,
self.measures)
for combination in self.param_combinations
)
performances_list = Parallel(n_jobs=self.n_jobs,
pre_dispatch=self.pre_dispatch,
verbose=self.joblib_verbose)(delayed_list)
Expand Down
29 changes: 28 additions & 1 deletion tests/test_grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@


def test_grid_search_cv_results():
"""Ensure that the number of parameter combinations is correct."""
param_grid = {'n_epochs': [1, 2], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6], 'n_factors': [1], 'init_std_dev': [0]}
grid_search = GridSearch(SVD, param_grid)
Expand All @@ -33,6 +34,7 @@ def test_grid_search_cv_results():


def test_measure_is_not_case_sensitive():
"""Ensure that all best_* dictionaries are case insensitive."""
param_grid = {'n_epochs': [1], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6], 'n_factors': [1], 'init_std_dev': [0]}
grid_search = GridSearch(SVD, param_grid, measures=['FCP', 'mae', 'rMSE'])
Expand All @@ -43,6 +45,8 @@ def test_measure_is_not_case_sensitive():


def test_best_estimator():
"""Ensure that the best estimator is the one giving the best score (by
re-running it)"""
param_grid = {'n_epochs': [5], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6], 'n_factors': [1], 'init_std_dev': [0]}
grid_search = GridSearch(SVD, param_grid, measures=['FCP', 'mae', 'rMSE'])
Expand All @@ -54,7 +58,8 @@ def test_best_estimator():

def test_dict_parameters():
"""Dict parameters like bsl_options and sim_options require special
treatment. We here test both in one shot with KNNBaseline."""
treatment in the param_grid argument. We here test both in one shot with
KNNBaseline."""

param_grid = {'bsl_options': {'method': ['als', 'sgd'],
'reg': [1, 2]},
Expand All @@ -68,3 +73,25 @@ def test_dict_parameters():
measures=['FCP', 'mae', 'rMSE'])
grid_search.evaluate(data)
assert len(grid_search.cv_results['params']) == 32


def test_same_splits():
    """Check that every parameter combination is evaluated on identical
    splits. Average RMSE scores being all equal is taken as sufficient
    evidence of that."""

    data_file = os.path.join(os.path.dirname(__file__), './u1_ml100k_train')
    data = Dataset.load_from_file(data_file, reader=Reader('ml-100k'))
    data.split(3)

    # Duplicated parameter values: every combination trains the exact same
    # model, so RMSE can only differ if the splits differ.
    param_grid = {'n_epochs': [1, 1], 'lr_all': [.5, .5]}
    grid_search = GridSearch(SVD, param_grid, measures=['RMSE'], n_jobs=-1)

    scores = []
    # Run the grid search twice: the second pass verifies that the splits
    # are still the same when evaluate() is called again.
    for _ in range(2):
        grid_search.evaluate(data)
        scores.extend(s['RMSE'] for s in grid_search.cv_results['scores'])
        assert len(set(scores)) == 1  # all RMSE values must coincide

0 comments on commit 712b670

Please sign in to comment.