Skip to content

Commit

Permalink
Slightly changed the shuffling part of split()
Browse files Browse the repository at this point in the history
- Shuffling is now done in split() rather than when raw_folds() is called,
which is more natural
- Added some tests / updated old ones
  • Loading branch information
NicolasHug committed Nov 21, 2017
1 parent 9ab72e5 commit 5d4011a
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Gemfile.lock
_site

.coverage
tags
20 changes: 10 additions & 10 deletions surprise/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def load_builtin(cls, name='ml-100k'):
<DatasetAutoFolds.split>` method. See an example in the :ref:`User
Guide <load_builtin_example>`.
Args:
name(:obj:`string`): The name of the built-in dataset to load.
Accepted values are 'ml-100k', 'ml-1m', and 'jester'.
Expand Down Expand Up @@ -318,8 +317,7 @@ class DatasetAutoFolds(Dataset):
def __init__(self, ratings_file=None, reader=None, df=None):

Dataset.__init__(self, reader)
self.n_folds = 5
self.shuffle = True
self.has_been_split = False # flag indicating if split() was called.

if ratings_file is not None:
self.ratings_file = ratings_file
Expand Down Expand Up @@ -347,16 +345,12 @@ def build_full_trainset(self):

def raw_folds(self):

if self.shuffle:
random.shuffle(self.raw_ratings)
self.shuffle = False # set to false for future calls to raw_folds
if not self.has_been_split:
self.split()

def k_folds(seq, n_folds):
"""Inspired from scikit learn KFold method."""

if n_folds > len(seq) or n_folds < 2:
raise ValueError('Incorrect value for n_folds.')

start, stop = 0, 0
for fold_i in range(n_folds):
start = stop
Expand Down Expand Up @@ -386,8 +380,14 @@ def split(self, n_folds=5, shuffle=True):
experiment is run. Default is ``True``.
"""

if n_folds > len(self.raw_ratings) or n_folds < 2:
raise ValueError('Incorrect value for n_folds.')

if shuffle:
random.shuffle(self.raw_ratings)

self.n_folds = n_folds
self.shuffle = shuffle
self.has_been_split = True


class Reader():
Expand Down
62 changes: 39 additions & 23 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import os
import random

import pytest
import pandas as pd
Expand All @@ -14,6 +15,7 @@
from surprise import Reader


random.seed(1)
reader = Reader(line_format='user item rating', sep=' ', skip_lines=3,
rating_scale=(1, 5))

Expand Down Expand Up @@ -41,51 +43,65 @@ def test_build_full_trainset():
assert trainset.n_items == 2


def test_split():
"""Test the split method."""
def test_no_call_to_split():
"""Ensure, as mentioned in the split() docstring, that even if split is not
called, the data is still split into 5 folds after being shuffled."""

custom_dataset_path = (os.path.dirname(os.path.realpath(__file__)) +
'/custom_dataset')
data = Dataset.load_from_file(file_path=custom_dataset_path, reader=reader)

# Test n_folds parameter
data.split(5)
assert len(list(data.folds())) == 5

with pytest.raises(ValueError):
data.split(10)
for fold in data.folds():
pass
# make sure data has been shuffled. If not shuffled, the users in the
# testsets would be 0, 1, 2... 4 (in that order).
users = [int(testset[0][0][-1]) for (_, testset) in data.folds()]
assert users != list(range(5))

with pytest.raises(ValueError):
data.split(1)
for fold in data.folds():
pass

def test_split():
"""Test the split method."""

custom_dataset_path = (os.path.dirname(os.path.realpath(__file__)) +
'/custom_dataset')
data = Dataset.load_from_file(file_path=custom_dataset_path, reader=reader)

# Test the shuffle parameter
# Make sure data has not been shuffled. If not shuffled, the users in the
# testsets are 0, 1, 2... 4 (in that order).
data.split(n_folds=5, shuffle=False)
users = [int(testset[0][0][-1]) for (_, testset) in data.folds()]
assert users == list(range(5))

# Test the shuffle parameter
# Make sure that when called two times without shuffling, folds are the
# same.
data.split(n_folds=3, shuffle=False)
testsets_a = [testset for (_, testset) in data.folds()]
data.split(n_folds=3, shuffle=False)
testsets_b = [testset for (_, testset) in data.folds()]
assert testsets_a == testsets_b

# We'll shuffle and check that folds are now different. There's a chance
# that they're still the same, just by bad luck. If after 10000 tries
# they're still the same, there's a high probability that our code is
# faulty. If we're very (very very very) unlucky, it may fail though (or
# loop for eternity).
i = 0
while testsets_a == testsets_b:
data.split(n_folds=3, shuffle=True)
testsets_b = [testset for (_, testset) in data.folds()]
i += 1
assert i < 10000
# We'll now shuffle b and check that folds are different.
data.split(n_folds=3, shuffle=True)
testsets_b = [testset for (_, testset) in data.folds()]
assert testsets_a != testsets_b

# Ensure that folds are the same if split is not called again
testsets_a = [testset for (_, testset) in data.folds()]
testsets_b = [testset for (_, testset) in data.folds()]
assert testsets_a == testsets_b

# Test n_folds parameter
data.split(5)
assert len(list(data.folds())) == 5

with pytest.raises(ValueError):
data.split(10) # Too big (greater than number of ratings)

with pytest.raises(ValueError):
data.split(1) # Too low (must be >= 2)


def test_trainset_testset():
"""Test the construct_trainset and construct_testset methods."""
Expand Down

0 comments on commit 5d4011a

Please sign in to comment.