Skip to content

Commit

Permalink
[MRG+1] fix StratifiedShuffleSplit with 2d y (scikit-learn#9044)
Browse files (browse the repository at this point in the history)
* regression test and fix for 2d stratified shuffle split

* strengthen non-overlap sss tests

* clarify test and comment

* remove iter from tests, use str instead of hash
  • Loading branch information
vene authored and AishwaryaRK committed Aug 29, 2017
1 parent 269977e commit 6713cd1
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
5 changes: 5 additions & 0 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,6 +1478,11 @@ def _iter_indices(self, X, y, groups=None):
y = check_array(y, ensure_2d=False, dtype=None)
n_train, n_test = _validate_shuffle_split(n_samples, self.test_size,
self.train_size)

if y.ndim == 2:
# for multi-label y, map each distinct row to its string repr:
y = np.array([str(row) for row in y])

classes, y_indices = np.unique(y, return_inverse=True)
n_classes = classes.shape[0]

Expand Down
29 changes: 28 additions & 1 deletion sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,10 +663,37 @@ def test_stratified_shuffle_split_overlap_train_test_bug():
sss = StratifiedShuffleSplit(n_splits=1,
test_size=0.5, random_state=0)

train, test = next(iter(sss.split(X=X, y=y)))
train, test = next(sss.split(X=X, y=y))

# no overlap
assert_array_equal(np.intersect1d(train, test), [])

# complete partition
assert_array_equal(np.union1d(train, test), np.arange(len(y)))


def test_stratified_shuffle_split_multilabel():
    """Regression test for issue 9037: StratifiedShuffleSplit with 2d y.

    For each label matrix, a single 50/50 split must yield disjoint train
    and test index sets that together cover every sample, and both halves
    must keep the overall proportion of each distinct label row.
    """
    label_matrices = [
        np.array([[0, 1], [1, 0], [1, 0], [0, 1]]),
        np.array([[0, 1], [1, 1], [1, 1], [0, 1]]),
    ]

    for y in label_matrices:
        X = np.ones_like(y)
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.5,
                                          random_state=0)
        train_idx, test_idx = next(splitter.split(X=X, y=y))

        # The two index sets must not share any sample ...
        shared = np.intersect1d(train_idx, test_idx)
        assert_array_equal(shared, [])

        # ... and together they must account for every sample exactly.
        covered = np.union1d(train_idx, test_idx)
        assert_array_equal(covered, np.arange(len(y)))

        # Whole rows must be stratified. In both fixtures the first column
        # alone identifies the row, so comparing its mean across the full
        # set and each half is enough to check the row proportions.
        first_col = y[:, 0]
        expected_ratio = np.mean(first_col)
        assert_equal(expected_ratio, np.mean(first_col[train_idx]))
        assert_equal(expected_ratio, np.mean(first_col[test_idx]))


def test_predefinedsplit_with_kfold_split():
# Check that PredefinedSplit can reproduce a split generated by Kfold.
Expand Down

0 comments on commit 6713cd1

Please sign in to comment.