Skip to content

Commit

Permalink
Put helper seed_and_eval function outside GridSearch class as it seem…
Browse files Browse the repository at this point in the history
…s that the method could not be pickled with some versions (hence travis failed)
  • Loading branch information
NicolasHug committed Oct 24, 2017
1 parent b92e3b8 commit 1d4322e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 12 deletions.
3 changes: 2 additions & 1 deletion doc/source/evaluate.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ evaluate module

.. automodule:: surprise.evaluate
:members:
:exclude-members: CaseInsensitiveDefaultDict, CaseInsensitiveDefaultDictForBestResults
:exclude-members: CaseInsensitiveDefaultDict,
CaseInsensitiveDefaultDictForBestResults, seed_and_eval
24 changes: 13 additions & 11 deletions surprise/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,6 @@ def __init__(self, algo_class, param_grid, measures=['rmse', 'mae'],
self.param_combinations = [dict(zip(self.param_grid, v)) for v in
product(*self.param_grid.values())]

def eval_helper(self, *args):
"""Helper function that calls evaluate.evaluate() *after* having seeded
the RNG. RNG seeding is mandatory since evalute() is called by
different processes."""

random.seed(self.seed)
return evaluate(*args, verbose=0)

def evaluate(self, data):
"""Runs the grid search on dataset.
Expand All @@ -232,9 +224,10 @@ def evaluate(self, data):
print(combination)

delayed_list = (
delayed(self.eval_helper)(self.algo_class(**combination),
data,
self.measures)
delayed(seed_and_eval)(self.seed,
self.algo_class(**combination),
data,
self.measures)
for combination in self.param_combinations
)
performances_list = Parallel(n_jobs=self.n_jobs,
Expand Down Expand Up @@ -308,3 +301,12 @@ def print_perf(performances):
for (key, vals) in iteritems(performances))

print(s)


def seed_and_eval(seed, *args):
"""Helper function that calls evaluate.evaluate() *after* having seeded
the RNG. RNG seeding is mandatory since evalute() is called by
different processes."""

random.seed(seed)
return evaluate(*args, verbose=0)

0 comments on commit 1d4322e

Please sign in to comment.