Skip to content

Commit

Permalink
add class GapCrossValidator
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieZ committed Apr 29, 2019
1 parent f8af325 commit 3d8ccef
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sklearn/model_selection/__init__.py
Expand Up @@ -15,6 +15,7 @@
from ._split import PredefinedSplit
from ._split import train_test_split
from ._split import check_cv
from ._split import GapCrossValidator

from ._validation import cross_val_score
from ._validation import cross_val_predict
Expand Down Expand Up @@ -49,6 +50,7 @@
'StratifiedKFold',
'StratifiedShuffleSplit',
'check_cv',
'GapCrossValidator',
'cross_val_predict',
'cross_val_score',
'cross_validate',
Expand Down
124 changes: 123 additions & 1 deletion sklearn/model_selection/_split.py
Expand Up @@ -42,7 +42,8 @@
'StratifiedShuffleSplit',
'PredefinedSplit',
'train_test_split',
'check_cv']
'check_cv',
'GapCrossValidator']


NSPLIT_WARNING = (
Expand Down Expand Up @@ -2155,3 +2156,124 @@ def _build_repr(self):
params[key] = value

return '%s(%s)' % (class_name, _pprint(params, offset=len(class_name)))

class GapCrossValidator(metaclass=ABCMeta):
"""Base class for all gap cross-validators"""

def __init__(self, gap_before=0, gap_after=0):
self.gap_before = gap_before
self.gap_after = gap_after

def split(self, X, y=None, groups=None):
"""Generate indices to split data into training and test set.
Parameters
----------
X : array-like, shape (n_samples, n_features)
Training data, where n_samples is the number of samples
and n_features is the number of features.
y : array-like, of length n_samples
The target variable for supervised learning problems.
groups : array-like, with shape (n_samples,), optional
Group labels for the samples used while splitting the dataset into
train/test set.
Yields
------
train : ndarray
The training set indices for that split.
test : ndarray
The testing set indices for that split.
"""
X, y, groups = indexable(X, y, groups)
indices = np.arange(_num_samples(X))
for train_index, test_index in self._iter_indices(X, y, groups):
yield train_index, test_index

# Since subclasses implement any of the following 5 methods,
# none can be abstract.
def _iter_indices(self, X=None, y=None, groups=None):
"""Generates integer indices corresponding to both training sets and
test sets.
By default, delegates to _iter_train_indices(X, y, groups) and
_iter_test_indices(X, y, groups)
"""
for a, b in zip(self._iter_train_indices(X, y, groups),
self._iter_test_indices(X, y, groups)):
yield a, b

def _iter_train_indices(self, X=None, y=None, groups=None):
"""Generates integer indices corresponding to training sets.
By default, delegates to _iter_test_indices(X, y, groups)
"""
return self.__complement_indices(
self._iter_test_indices(X, y, groups), _num_samples(X))

def _iter_test_indices(self, X=None, y=None, groups=None):
"""Generates integer indices corresponding to test sets.
By default, delegates to _iter_test_masks(X, y, groups)
"""
return GapCrossValidator.__marks_to_indices(
self._iter_test_masks(X, y, groups))

def _iter_test_masks(self, X=None, y=None, groups=None):
"""Generates boolean masks corresponding to test sets.
By default, delegates to _iter_train_masks(X, y, groups)
"""
return self.__complement_marks(self._iter_train_masks(X, y, groups))

def _iter_train_masks(self, X=None, y=None, groups=None):
"""Generates boolean masks corresponding to training sets.
By default, delegates to _iter_train_indices(X, y, groups)
"""
return GapCrossValidator.__indices_to_masks(
self._iter_train_indices(X, y, groups), _num_samples(X))

@staticmethod
def __marks_to_indices(marks):
for mark in marks:
index = np.arange(len(mark))
yield index[mark]

@staticmethod
def __indices_to_masks(indices, n_samples):
for index in indices:
mark = np.zeros(n_samples, dtype=np.bool)
mark[index] = True
yield mark

def __complement_marks(self, marks):
before, after = self.gap_before, self.gap_after
for mark in marks:
complement = np.ones(len(mark), dtype=np.bool)
for i, marked in enumerate(mark):
if marked: # then make its neighbourhood False
begin = max(i - before, 0)
end = min(i + after + 1, len(complement))
complement[np.arange(begin, end)] = False
yield complement

def __complement_indices(self, indices, n_samples):
before, after = self.gap_before, self.gap_after
for index in indices:
complement = np.arange(n_samples)
for i in index:
begin = max(i - before, 0)
end = min(i + after + 1, n_samples)
complement = np.setdiff1d(complement, np.arange(begin, end))
yield complement

@abstractmethod
def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the cross-validator"""

def __repr__(self):
return _build_repr(self)

0 comments on commit 3d8ccef

Please sign in to comment.