Skip to content

Commit

Permalink
Added min_n_ratings param to LeaveOneOut CV
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jan 14, 2018
1 parent 4ce22ca commit a861696
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 14 deletions.
30 changes: 17 additions & 13 deletions surprise/model_selection/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ class LeaveOneOut():
'''Cross-validation iterator where each user has exactly one rating in the
testset.
Contrary to other cross-validation strategies, random splits do not
Contrary to other cross-validation strategies, ``LeaveOneOut`` does not
guarantee that all folds will be different, although this is still very
likely for sizeable datasets.
Expand All @@ -353,15 +353,19 @@ class LeaveOneOut():
If RandomState instance, this same instance is used as RNG. If
``None``, the current RNG from numpy is used. ``random_state`` is
only used if ``shuffle`` is ``True``. Default is ``None``.
shuffle(bool): Whether to shuffle the ratings in the ``data`` parameter
of the ``split()`` method. Shuffling is not done in-place. Default
is ``True``.
min_n_ratings(int): Minimum number of ratings for each user in the
trainset. E.g. if ``min_n_ratings`` is ``2``, we are sure each user
has at least ``2`` ratings in the trainset (and ``1`` in the
testset). Other users are discarded. Default is ``0``, so some users
(having only one rating) may be in the testset and not in the
trainset.
'''

def __init__(self, n_splits=5, random_state=None):
def __init__(self, n_splits=5, random_state=None, min_n_ratings=0):

self.n_splits = n_splits
self.random_state = random_state
self.min_n_ratings = min_n_ratings

def split(self, data):
'''Generator function to iterate over trainsets and testsets.
Expand All @@ -383,18 +387,18 @@ def split(self, data):

for _ in range(self.n_splits):
# for each user, randomly choose a rating and put it in the
# testset. Note that as some users will have only 1 rating in the
# dataset, this means they won't be trained on.
# testset.
raw_trainset, raw_testset = [], []
for uid, ratings in iteritems(user_ratings):
i = rng.randint(0, len(ratings))
raw_testset.append(ratings[i])
raw_trainset += [rating for (j, rating) in enumerate(ratings)
if j != i]
if len(ratings) > self.min_n_ratings:
i = rng.randint(0, len(ratings))
raw_testset.append(ratings[i])
raw_trainset += [rating for (j, rating)
in enumerate(ratings) if j != i]

if not raw_trainset:
raise ValueError('Each user only has one rating. Cannot '
'Run LOO cross-validation')
raise ValueError('Could not build any trainset. Maybe '
'min_n_ratings is too high?')
trainset = data.construct_trainset(raw_trainset)
testset = data.construct_testset(raw_testset)

Expand Down
15 changes: 14 additions & 1 deletion tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from copy import copy
import numpy as np
from collections import Counter
from six import itervalues

import pytest
from six import itervalues

from surprise import Dataset
from surprise import Reader
Expand Down Expand Up @@ -302,6 +302,19 @@ def test_LeaveOneOut():
cnt = Counter([uid for (uid, _, _) in testset])
assert all(val == 1 for val in itervalues(cnt))

# test the min_n_ratings parameter
loo = LeaveOneOut(min_n_ratings=5)
for trainset, _ in loo.split(data):
assert all(len(ratings) >= 5 for ratings in itervalues(trainset.ur))

loo = LeaveOneOut(min_n_ratings=10)
for trainset, _ in loo.split(data):
assert all(len(ratings) >= 10 for ratings in itervalues(trainset.ur))

loo = LeaveOneOut(min_n_ratings=10000) # too high
with pytest.raises(ValueError):
next(loo.split(data))


def test_PredifinedKFold():

Expand Down

0 comments on commit a861696

Please sign in to comment.