Skip to content

Commit

Permalink
Update cv_results example
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jan 9, 2018
1 parent c52385f commit 8bc21cf
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 21 deletions.
59 changes: 40 additions & 19 deletions doc/source/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ please:
'user_based': [False]}
}
.. _cv_results_example:

For further analysis, the ``cv_results`` attribute has all the needed
information and can be imported in a pandas dataframe:

Expand All @@ -320,30 +322,49 @@ information and can be imported in a pandas dataframe:
:name: grid_search_usage3.py
:lines: 33

In our example, the ``cv_results`` attribute looks like this:
In our example, the ``cv_results`` attribute looks like this (floats are
formatted):

.. parsed-literal::
'split0_test_rmse': [1.0, 1.01, 0.98, 0.99, 0.98, 0.99, 0.97, 0.98]
'split1_test_rmse': [1.0, 1.0, 0.97, 0.98, 0.98, 0.99, 0.96, 0.97]
'split2_test_rmse': [0.99, 1.0, 0.97, 0.98, 0.97, 0.98, 0.96, 0.97]
'mean_test_rmse': [1.0, 1.0, 0.97, 0.98, 0.98, 0.99, 0.96, 0.97]
'std_test_rmse': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
'rank_test_rmse': [7 8 3 5 4 6 1 2]
'split0_test_mae': [0.81, 0.82, 0.78, 0.8, 0.79, 0.8, 0.78, 0.79]
'split1_test_mae': [0.81, 0.82, 0.78, 0.79, 0.79, 0.8, 0.77, 0.79]
'split2_test_mae': [0.8, 0.81, 0.78, 0.79, 0.78, 0.79, 0.77, 0.78]
'mean_test_mae': [0.81, 0.81, 0.78, 0.79, 0.79, 0.8, 0.77, 0.78]
'std_test_mae': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
'rank_test_mae': [7 8 2 5 4 6 1 3]
'mean_fit_time': [1.5, 1.58, 1.58, 1.51, 2.99, 3.01, 3.05, 3.06]
'std_fit_time': [0.02, 0.03, 0.05, 0.01, 0.02, 0.02, 0.05, 0.04]
'mean_test_time': [0.45, 0.47, 0.44, 0.45, 0.44, 0.47, 0.47, 0.31]
'std_test_time': [0.0, 0.04, 0.01, 0.0, 0.01, 0.04, 0.03, 0.09]
'params': [{'n_epochs': 5, 'lr_all': 0.002, 'reg_all': 0.4}, {'n_epochs': 5, 'lr_all': 0.002, 'reg_all': 0.6}, {'n_epochs': 5, 'lr_all': 0.005, 'reg_all': 0.4}, {'n_epochs': 5, 'lr_all': 0.005, 'reg_all': 0.6}, {'n_epochs': 10, 'lr_all': 0.002, 'reg_all': 0.4}, {'n_epochs': 10, 'lr_all': 0.002, 'reg_all': 0.6}, {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.4}, {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.6}]
'split0_test_rmse': [1.0, 1.0, 0.97, 0.98, 0.98, 0.99, 0.96, 0.97]
'split1_test_rmse': [1.0, 1.0, 0.97, 0.98, 0.98, 0.99, 0.96, 0.97]
'split2_test_rmse': [1.0, 1.0, 0.97, 0.98, 0.98, 0.99, 0.96, 0.97]
'mean_test_rmse': [1.0, 1.0, 0.97, 0.98, 0.98, 0.99, 0.96, 0.97]
'std_test_rmse': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
'rank_test_rmse': [7 8 3 5 4 6 1 2]
'split0_test_mae': [0.81, 0.82, 0.78, 0.79, 0.79, 0.8, 0.77, 0.79]
'split1_test_mae': [0.8, 0.81, 0.78, 0.79, 0.78, 0.79, 0.77, 0.78]
'split2_test_mae': [0.81, 0.81, 0.78, 0.79, 0.78, 0.8, 0.77, 0.78]
'mean_test_mae': [0.81, 0.81, 0.78, 0.79, 0.79, 0.8, 0.77, 0.78]
'std_test_mae': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
'rank_test_mae': [7 8 2 5 4 6 1 3]
'mean_fit_time': [1.53, 1.52, 1.53, 1.53, 3.04, 3.05, 3.06, 3.02]
'std_fit_time': [0.03, 0.04, 0.0, 0.01, 0.04, 0.01, 0.06, 0.01]
'mean_test_time': [0.46, 0.45, 0.44, 0.44, 0.47, 0.49, 0.46, 0.34]
'std_test_time': [0.0, 0.01, 0.01, 0.0, 0.03, 0.06, 0.01, 0.08]
'params': [{'n_epochs': 5, 'lr_all': 0.002, 'reg_all': 0.4}, {'n_epochs': 5, 'lr_all': 0.002, 'reg_all': 0.6}, {'n_epochs': 5, 'lr_all': 0.005, 'reg_all': 0.4}, {'n_epochs': 5, 'lr_all': 0.005, 'reg_all': 0.6}, {'n_epochs': 10, 'lr_all': 0.002, 'reg_all': 0.4}, {'n_epochs': 10, 'lr_all': 0.002, 'reg_all': 0.6}, {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.4}, {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.6}]
'param_n_epochs': [5, 5, 5, 5, 10, 10, 10, 10]
'param_lr_all': [0.0, 0.0, 0.01, 0.01, 0.0, 0.0, 0.01, 0.01]
'param_reg_all': [0.4, 0.6, 0.4, 0.6, 0.4, 0.6, 0.4, 0.6]
As you can see, each list has the same size of the number of parameter
combination.
combination. It corresponds to the following table:

================== ================== ================== ================ =============== ================ ================= ================= ================= =============== ============== =============== =============== ============== ================ =============== ================================================= ================ ============== ===============
split0_test_rmse split1_test_rmse split2_test_rmse mean_test_rmse std_test_rmse rank_test_rmse split0_test_mae split1_test_mae split2_test_mae mean_test_mae std_test_mae rank_test_mae mean_fit_time std_fit_time mean_test_time std_test_time params param_n_epochs param_lr_all param_reg_all
================== ================== ================== ================ =============== ================ ================= ================= ================= =============== ============== =============== =============== ============== ================ =============== ================================================= ================ ============== ===============
0.99775 0.997744 0.996378 0.997291 0.000645508 7 0.807862 0.804626 0.805282 0.805923 0.00139657 7 1.53341 0.0305216 0.455831 0.000922113 {'n_epochs': 5, 'lr_all': 0.002, 'reg_all': 0.4} 5 0.002 0.4
1.00381 1.00304 1.00257 1.00314 0.000508358 8 0.816559 0.812905 0.813772 0.814412 0.00155866 8 1.5199 0.0367117 0.451068 0.00938646 {'n_epochs': 5, 'lr_all': 0.002, 'reg_all': 0.6} 5 0.002 0.6
0.973524 0.973595 0.972495 0.973205 0.000502609 3 0.783361 0.780242 0.78067 0.781424 0.00138049 2 1.53449 0.00496203 0.441558 0.00529696 {'n_epochs': 5, 'lr_all': 0.005, 'reg_all': 0.4} 5 0.005 0.4
0.98229 0.982059 0.981486 0.981945 0.000338056 5 0.794481 0.790781 0.79186 0.792374 0.00155377 5 1.52739 0.00859185 0.44463 0.000888907 {'n_epochs': 5, 'lr_all': 0.005, 'reg_all': 0.6} 5 0.005 0.6
0.978034 0.978407 0.976919 0.977787 0.000632049 4 0.787643 0.784723 0.784957 0.785774 0.00132486 4 3.03572 0.0431101 0.466606 0.0254965 {'n_epochs': 10, 'lr_all': 0.002, 'reg_all': 0.4} 10 0.002 0.4
0.986263 0.985817 0.985004 0.985695 0.000520899 6 0.798218 0.794457 0.795373 0.796016 0.00160135 6 3.0544 0.00636185 0.488357 0.0576194 {'n_epochs': 10, 'lr_all': 0.002, 'reg_all': 0.6} 10 0.002 0.6
0.963751 0.963463 0.962676 0.963297 0.000454661 1 0.774036 0.770548 0.771588 0.772057 0.00146201 1 3.0636 0.0597982 0.456484 0.00510321 {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.4} 10 0.005 0.4
0.973605 0.972868 0.972765 0.973079 0.000374222 2 0.78607 0.781918 0.783537 0.783842 0.00170855 3 3.01907 0.011834 0.338839 0.075346 {'n_epochs': 10, 'lr_all': 0.005, 'reg_all': 0.6} 10 0.005 0.6
================== ================== ================== ================ =============== ================ ================= ================= ================= =============== ============== =============== =============== ============== ================ =============== ================================================= ================ ============== ===============



Command line usage
------------------
Expand Down
40 changes: 40 additions & 0 deletions examples/generate_grid_search_cv_results_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""
This module is used for generating the doc tables about the
GridSearchCV.cv_results attribute.
"""

from __future__ import (absolute_import, division, print_function,
unicode_literals)

from tabulate import tabulate
from six import iteritems

from surprise import SVD
from surprise import Dataset
from surprise.model_selection import GridSearchCV

# Use movielens-100K
data = Dataset.load_builtin('ml-100k')

param_grid = {'n_epochs': [5, 10], 'lr_all': [0.002, 0.005],
'reg_all': [0.4, 0.6]}
gs = GridSearchCV(SVD, param_grid, measures=['rmse', 'mae'], cv=3)

gs.fit(data)

table = [[] for _ in range(len(gs.cv_results['params']))]
for i in range(len(gs.cv_results['params'])):
for key in gs.cv_results.keys():
table[i].append(gs.cv_results[key][i])

header = gs.cv_results.keys()
print(tabulate(table, header, tablefmt="rst"))

print()

for key, val in iteritems(gs.cv_results):
print('{:<20}'.format("'" + key + "':"), end='')
if isinstance(val[0], float):
print([float('{:.2f}'.format(f)) for f in val])
else:
print(val)
8 changes: 6 additions & 2 deletions surprise/model_selection/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ class GridSearchCV:
cv_results (dict of arrays):
A dict that contains accuracy measures over all splits, as well as
train and test time for each parameter combination. Can be imported
into a pandas `DataFrame`.
into a pandas `DataFrame` (see :ref:`example
<cv_results_example>`).
'''

def __init__(self, algo_class, param_grid, measures=['rmse', 'mae'],
Expand Down Expand Up @@ -242,8 +243,11 @@ def fit(self, data):
cv_results['mean_{}_time'.format(s)] = times.mean(axis=1)
cv_results['std_{}_time'.format(s)] = times.std(axis=1)

# cv_results: set params key
# cv_results: set params key and each param_* values
cv_results['params'] = self.param_combinations
for param in self.param_combinations[0]:
cv_results['param_' + param] = [comb[param] for comb in
self.param_combinations]

if self.refit:
best_estimator[self.refit].fit(data.build_full_trainset())
Expand Down

0 comments on commit 8bc21cf

Please sign in to comment.