Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add easier way to determine whether data splitter is CV #3297

Merged
merged 10 commits into from
Feb 7, 2022
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Enhancements
* Fixes
* Changes
* Added an ``is_cv`` property to the datasplitters used :pr:`3297`
* Documentation Changes
* Testing Changes

Expand Down
3 changes: 2 additions & 1 deletion evalml/automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from collections import namedtuple

import pandas as pd
from sklearn.model_selection import KFold, StratifiedKFold

from evalml.objectives import get_objective
from evalml.pipelines import (
Expand All @@ -14,6 +13,8 @@
TimeSeriesRegressionPipeline,
)
from evalml.preprocessing.data_splitters import (
KFold,
StratifiedKFold,
TimeSeriesSplit,
TrainingValidationSplit,
)
Expand Down
1 change: 1 addition & 0 deletions evalml/preprocessing/data_splitters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .no_split import NoSplit
from .training_validation_split import TrainingValidationSplit
from .time_series_split import TimeSeriesSplit
from .sk_splitters import KFold, StratifiedKFold
9 changes: 9 additions & 0 deletions evalml/preprocessing/data_splitters/no_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def get_n_splits():
"""
return 0

@property
bchen1116 marked this conversation as resolved.
Show resolved Hide resolved
def is_cv(self):
"""Returns whether or not the data splitter is a cross-validation data splitter.

Returns:
bool: If the splitter is a cross-validation data splitter
"""
return False

def split(self, X, y=None):
"""Divide the data into training and testing sets, where the testing set is empty.

Expand Down
28 changes: 28 additions & 0 deletions evalml/preprocessing/data_splitters/sk_splitters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""SKLearn data splitter wrapper classes."""
from sklearn.model_selection import KFold, StratifiedKFold


class KFold(KFold):
"""Wrapper class for sklearn's KFold splitter."""

@property
def is_cv(self):
"""Returns whether or not the data splitter is a cross-validation data splitter.

Returns:
bool: If the splitter is a cross-validation data splitter
"""
return True


class StratifiedKFold(StratifiedKFold):
"""Wrapper class for sklearn's Stratified KFold splitter."""

@property
def is_cv(self):
"""Returns whether or not the data splitter is a cross-validation data splitter.

Returns:
bool: If the splitter is a cross-validation data splitter
"""
return True
9 changes: 9 additions & 0 deletions evalml/preprocessing/data_splitters/time_series_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ def get_n_splits(self, X=None, y=None, groups=None):
def _check_if_empty(data):
return data is None or data.empty

@property
def is_cv(self):
"""Returns whether or not the data splitter is a cross-validation data splitter.

Returns:
bool: If the splitter is a cross-validation data splitter
"""
return self._splitter.n_splits > 1

def split(self, X, y=None, groups=None):
"""Get the time series splits.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,15 @@ def get_n_splits():
"""
return 1

@property
def is_cv(self):
"""Returns whether or not the data splitter is a cross-validation data splitter.

Returns:
bool: If the splitter is a cross-validation data splitter
"""
return False

def split(self, X, y=None):
"""Divide the data into training and testing sets.

Expand Down
5 changes: 5 additions & 0 deletions evalml/tests/automl_tests/test_automl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,22 @@ def test_make_data_splitter_default(problem_type, large_data):
assert data_splitter.random_seed == 0
assert data_splitter.shuffle
assert data_splitter.test_size == _LARGE_DATA_PERCENT_VALIDATION
assert not data_splitter.is_cv
return

if problem_type == ProblemTypes.REGRESSION:
assert isinstance(data_splitter, KFold)
assert data_splitter.n_splits == 3
assert data_splitter.shuffle
assert data_splitter.random_state == 0
assert data_splitter.is_cv

if problem_type in [ProblemTypes.BINARY, ProblemTypes.MULTICLASS]:
assert isinstance(data_splitter, StratifiedKFold)
assert data_splitter.n_splits == 3
assert data_splitter.shuffle
assert data_splitter.random_state == 0
assert data_splitter.is_cv

if problem_type in [
ProblemTypes.TIME_SERIES_REGRESSION,
Expand All @@ -132,6 +135,7 @@ def test_make_data_splitter_default(problem_type, large_data):
assert data_splitter.max_delay == 7
assert data_splitter.forecast_horizon == 4
assert data_splitter.time_index == "foo"
assert data_splitter.is_cv


@pytest.mark.parametrize(
Expand All @@ -155,6 +159,7 @@ def test_make_data_splitter_parameters(problem_type, expected_data_splitter):
assert data_splitter.n_splits == 5
assert data_splitter.shuffle
assert data_splitter.random_state == random_seed
assert data_splitter.is_cv


def test_make_data_splitter_parameters_time_series():
Expand Down
1 change: 1 addition & 0 deletions evalml/tests/preprocessing_tests/test_no_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

def test_nosplit_nsplits():
assert NoSplit().get_n_splits() == 0
assert not NoSplit().is_cv


def test_nosplit_default():
Expand Down
33 changes: 33 additions & 0 deletions evalml/tests/preprocessing_tests/test_sk_splitters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import numpy as np
import pytest
from sklearn.model_selection import KFold as sk_kfold
from sklearn.model_selection import StratifiedKFold as sk_stratified

from evalml.preprocessing.data_splitters import KFold, StratifiedKFold


@pytest.mark.parametrize(
"sk_splitter,splitter", [[sk_kfold, KFold], [sk_stratified, StratifiedKFold]]
)
@pytest.mark.parametrize("problem_type", ["binary", "multiclass"])
def test_splitters_equal(problem_type, sk_splitter, splitter, X_y_binary, X_y_multi):
parameters = {"shuffle": True, "random_state": 0, "n_splits": 4}
sk_split = splitter(**parameters)
evalml_split = splitter(**parameters)
if problem_type == "binary":
X, y = X_y_binary
else:
X, y = X_y_multi

skt, skv = [], []
evt, evv = [], []

for t, v in sk_split.split(X, y):
skt.append(t)
skv.append(v)
for t, v in evalml_split.split(X, y):
evt.append(t)
evv.append(v)
np.testing.assert_array_equal(skt, evt)
np.testing.assert_array_equal(skv, evv)
assert evalml_split.is_cv
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def test_tvsplit_nsplits():
assert TrainingValidationSplit().get_n_splits() == 1
assert not TrainingValidationSplit().is_cv


def test_tvsplit_default():
Expand Down