Skip to content

Commit

Permalink
[MRG+1] Repeated K-Fold and Repeated Stratified K-Fold (scikit-learn#…
Browse files Browse the repository at this point in the history
…8120)

* Add _RepeatedSplits and RepeatedKFold class

* Add RepeatedStratifiedKFold and doc for repeated cvs

* Change default value of n_repeats

* Change input parameters of repeated cv constructor to n_splits, n_repeats, random_state

* Generate random states in split function rather than store it beforehand

* Doc changes, inheriting RepeatedKFold, RepeatedStratifiedKFold from _RepeatedSplits and other review changes

* Remove blank line, put testcases for deterministic split in loop and add StopIteration check in testcase

* Using rng directly as random_state param to create cv instance and added a check for cvargs

* Fix pep8 warnings

* Changing default values for n_splits and n_repeats and add entry in changelog

* Adding name to the feature

* Missing space
  • Loading branch information
neerajgangwar authored and Sundrique committed Jun 14, 2017
1 parent 5c92a20 commit f149d30
Show file tree
Hide file tree
Showing 6 changed files with 286 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ Splitter Classes
model_selection.LeavePGroupsOut
model_selection.LeaveOneOut
model_selection.LeavePOut
model_selection.RepeatedKFold
model_selection.RepeatedStratifiedKFold
model_selection.ShuffleSplit
model_selection.GroupShuffleSplit
model_selection.StratifiedShuffleSplit
Expand Down
31 changes: 31 additions & 0 deletions doc/modules/cross_validation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,33 @@ Thus, one can create the training/test sets using numpy indexing::
>>> X_train, X_test, y_train, y_test = X[train], X[test], y[train], y[test]


Repeated K-Fold
---------------

:class:`RepeatedKFold` repeats K-Fold n times. It can be used when one
requires to run :class:`KFold` n times, producing different splits in
each repetition.

Example of 2-fold K-Fold repeated 2 times::

>>> import numpy as np
>>> from sklearn.model_selection import RepeatedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> random_state = 12883823
>>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=random_state)
>>> for train, test in rkf.split(X):
... print("%s %s" % (train, test))
...
[2 3] [0 1]
[0 1] [2 3]
[0 2] [1 3]
[1 3] [0 2]


Similarly, :class:`RepeatedStratifiedKFold` repeats Stratified K-Fold n times
with different randomization in each repetition.


Leave One Out (LOO)
-------------------

Expand Down Expand Up @@ -409,6 +436,10 @@ two slightly unbalanced classes::
[0 1 3 4 5 8 9] [2 6 7]
[0 1 2 4 5 6 7] [3 8 9]

:class:`RepeatedStratifiedKFold` can be used to repeat Stratified K-Fold n times
with different randomization in each repetition.


Stratified Shuffle Split
------------------------

Expand Down
6 changes: 6 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ New features
Kullback-Leibler divergence and the Itakura-Saito divergence.
By `Tom Dupre la Tour`_.

- Added the :class:`sklearn.model_selection.RepeatedKFold` and
:class:`sklearn.model_selection.RepeatedStratifiedKFold`.
:issue:`8120` by `Neeraj Gangwar`_.

- Added :func:`metrics.mean_squared_log_error`, which computes
the mean square error of the logarithmic transformation of targets,
particularly useful for targets with an exponential trend.
Expand Down Expand Up @@ -5004,3 +5008,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Vincent Pham: https://github.com/vincentpham1991

.. _Denis Engemann: http://denis-engemann.de

.. _Neeraj Gangwar: http://neerajgangwar.in
4 changes: 4 additions & 0 deletions sklearn/model_selection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from ._split import LeaveOneOut
from ._split import LeavePGroupsOut
from ._split import LeavePOut
from ._split import RepeatedKFold
from ._split import RepeatedStratifiedKFold
from ._split import ShuffleSplit
from ._split import GroupShuffleSplit
from ._split import StratifiedShuffleSplit
Expand Down Expand Up @@ -36,6 +38,8 @@
'LeaveOneOut',
'LeavePGroupsOut',
'LeavePOut',
'RepeatedKFold',
'RepeatedStratifiedKFold',
'ParameterGrid',
'ParameterSampler',
'PredefinedSplit',
Expand Down
171 changes: 171 additions & 0 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
'LeaveOneOut',
'LeavePGroupsOut',
'LeavePOut',
'RepeatedStratifiedKFold',
'RepeatedKFold',
'ShuffleSplit',
'GroupShuffleSplit',
'StratifiedKFold',
Expand Down Expand Up @@ -397,6 +399,8 @@ class KFold(_BaseKFold):
classification tasks).
GroupKFold: K-fold iterator variant with non-overlapping groups.
RepeatedKFold: Repeats K-Fold n times.
"""

def __init__(self, n_splits=3, shuffle=False,
Expand Down Expand Up @@ -553,6 +557,9 @@ class StratifiedKFold(_BaseKFold):
All the folds have size ``trunc(n_samples / n_splits)``, the last one has
the complementary.
See also
--------
RepeatedStratifiedKFold: Repeats Stratified K-Fold n times.
"""

def __init__(self, n_splits=3, shuffle=False, random_state=None):
Expand Down Expand Up @@ -913,6 +920,170 @@ def get_n_splits(self, X, y, groups):
return int(comb(len(np.unique(groups)), self.n_groups, exact=True))


class _RepeatedSplits(with_metaclass(ABCMeta)):
"""Repeated splits for an arbitrary randomized CV splitter.
Repeats splits for cross-validators n times with different randomization
in each repetition.
Parameters
----------
cv : callable
Cross-validator class.
n_repeats : int, default=10
Number of times cross-validator needs to be repeated.
random_state : None, int or RandomState, default=None
Random state to be used to generate random state for each
repetition.
**cvargs : additional params
Constructor parameters for cv. Must not contain random_state
and shuffle.
"""
def __init__(self, cv, n_repeats=10, random_state=None, **cvargs):
if not isinstance(n_repeats, (np.integer, numbers.Integral)):
raise ValueError("Number of repetitions must be of Integral type.")

if n_repeats <= 1:
raise ValueError("Number of repetitions must be greater than 1.")

if any(key in cvargs for key in ('random_state', 'shuffle')):
raise ValueError(
"cvargs must not contain random_state or shuffle.")

self.cv = cv
self.n_repeats = n_repeats
self.random_state = random_state
self.cvargs = cvargs

def split(self, X, y=None, groups=None):
"""Generates indices to split data into training and test set.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data, where n_samples is the number of samples
and n_features is the number of features.
y : array-like, of length n_samples
The target variable for supervised learning problems.
groups : array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into
train/test set.
Returns
-------
train : ndarray
The training set indices for that split.
test : ndarray
The testing set indices for that split.
"""
n_repeats = self.n_repeats
rng = check_random_state(self.random_state)

for idx in range(n_repeats):
cv = self.cv(random_state=rng, shuffle=True,
**self.cvargs)
for train_index, test_index in cv.split(X, y, groups):
yield train_index, test_index


class RepeatedKFold(_RepeatedSplits):
"""Repeated K-Fold cross validator.
Repeats K-Fold n times with different randomization in each repetition.
Read more in the :ref:`User Guide <cross_validation>`.
Parameters
----------
n_splits : int, default=5
Number of folds. Must be at least 2.
n_repeats : int, default=10
Number of times cross-validator needs to be repeated.
random_state : None, int or RandomState, default=None
Random state to be used to generate random state for each
repetition.
Examples
--------
>>> from sklearn.model_selection import RepeatedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> rkf = RepeatedKFold(n_splits=2, n_repeats=2, random_state=2652124)
>>> for train_index, test_index in rkf.split(X):
... print("TRAIN:", train_index, "TEST:", test_index)
... X_train, X_test = X[train_index], X[test_index]
... y_train, y_test = y[train_index], y[test_index]
...
TRAIN: [0 1] TEST: [2 3]
TRAIN: [2 3] TEST: [0 1]
TRAIN: [1 2] TEST: [0 3]
TRAIN: [0 3] TEST: [1 2]
See also
--------
RepeatedStratifiedKFold: Repeates Stratified K-Fold n times.
"""
def __init__(self, n_splits=5, n_repeats=10, random_state=None):
super(RepeatedKFold, self).__init__(
KFold, n_repeats, random_state, n_splits=n_splits)


class RepeatedStratifiedKFold(_RepeatedSplits):
"""Repeated Stratified K-Fold cross validator.
Repeats Stratified K-Fold n times with different randomization in each
repetition.
Read more in the :ref:`User Guide <cross_validation>`.
Parameters
----------
n_splits : int, default=5
Number of folds. Must be at least 2.
n_repeats : int, default=10
Number of times cross-validator needs to be repeated.
random_state : None, int or RandomState, default=None
Random state to be used to generate random state for each
repetition.
Examples
--------
>>> from sklearn.model_selection import RepeatedStratifiedKFold
>>> X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
>>> y = np.array([0, 0, 1, 1])
>>> rskf = RepeatedStratifiedKFold(n_splits=2, n_repeats=2,
... random_state=36851234)
>>> for train_index, test_index in rskf.split(X, y):
... print("TRAIN:", train_index, "TEST:", test_index)
... X_train, X_test = X[train_index], X[test_index]
... y_train, y_test = y[train_index], y[test_index]
...
TRAIN: [1 2] TEST: [0 3]
TRAIN: [0 3] TEST: [1 2]
TRAIN: [1 3] TEST: [0 2]
TRAIN: [0 2] TEST: [1 3]
See also
--------
RepeatedKFold: Repeats K-Fold n times.
"""
def __init__(self, n_splits=5, n_repeats=10, random_state=None):
super(RepeatedStratifiedKFold, self).__init__(
StratifiedKFold, n_repeats, random_state, n_splits=n_splits)


class BaseShuffleSplit(with_metaclass(ABCMeta)):
"""Base class for ShuffleSplit and StratifiedShuffleSplit"""

Expand Down
72 changes: 72 additions & 0 deletions sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
from sklearn.model_selection import check_cv
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RepeatedKFold
from sklearn.model_selection import RepeatedStratifiedKFold

from sklearn.linear_model import Ridge

Expand Down Expand Up @@ -804,6 +806,76 @@ def test_leave_one_p_group_out_error_on_fewer_number_of_groups():
LeavePGroupsOut(n_groups=3).split(X, y, groups))


def test_repeated_cv_value_errors():
# n_repeats is not integer or <= 1
for cv in (RepeatedKFold, RepeatedStratifiedKFold):
assert_raises(ValueError, cv, n_repeats=1)
assert_raises(ValueError, cv, n_repeats=1.5)


def test_repeated_kfold_determinstic_split():
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
random_state = 258173307
rkf = RepeatedKFold(
n_splits=2,
n_repeats=2,
random_state=random_state)

# split should produce same and deterministic splits on
# each call
for _ in range(3):
splits = rkf.split(X)
train, test = next(splits)
assert_array_equal(train, [2, 4])
assert_array_equal(test, [0, 1, 3])

train, test = next(splits)
assert_array_equal(train, [0, 1, 3])
assert_array_equal(test, [2, 4])

train, test = next(splits)
assert_array_equal(train, [0, 1])
assert_array_equal(test, [2, 3, 4])

train, test = next(splits)
assert_array_equal(train, [2, 3, 4])
assert_array_equal(test, [0, 1])

assert_raises(StopIteration, next, splits)


def test_repeated_stratified_kfold_determinstic_split():
X = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
y = [1, 1, 1, 0, 0]
random_state = 1944695409
rskf = RepeatedStratifiedKFold(
n_splits=2,
n_repeats=2,
random_state=random_state)

# split should produce same and deterministic splits on
# each call
for _ in range(3):
splits = rskf.split(X, y)
train, test = next(splits)
assert_array_equal(train, [1, 4])
assert_array_equal(test, [0, 2, 3])

train, test = next(splits)
assert_array_equal(train, [0, 2, 3])
assert_array_equal(test, [1, 4])

train, test = next(splits)
assert_array_equal(train, [2, 3])
assert_array_equal(test, [0, 1, 4])

train, test = next(splits)
assert_array_equal(train, [0, 1, 4])
assert_array_equal(test, [2, 3])

assert_raises(StopIteration, next, splits)


def test_train_test_split_errors():
assert_raises(ValueError, train_test_split)
assert_raises(ValueError, train_test_split, range(3), train_size=1.1)
Expand Down

0 comments on commit f149d30

Please sign in to comment.