In [1]:
# Resolve path when used in a usecase project
import sys
from pathlib import Path

sys.path.insert(0, str(Path("../../").resolve()))

# Splitters

The `modeling` package contains several useful classes for splitting data into train and test datasets. We'll demonstrate each of them below; choose whichever fits your application best.

All of these classes share the API defined by `SplitterBase` ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.base_splitter.SplitterBase)): call the `.split` method on a dataset to split it into train and test datasets.
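The notebook doesn't show what `SplitterBase` requires from its subclasses, so the following is only a minimal sketch of the shared contract: a single `split(data)` call that returns a `(train, test)` pair and logs the sizes. `SketchSplitterBase`, its `_split` hook, and `FirstNSplitter` are hypothetical names, not part of `modeling`.

```python
import logging
from abc import ABC, abstractmethod
from typing import Tuple

import pandas as pd

logger = logging.getLogger(__name__)


class SketchSplitterBase(ABC):
    """Hypothetical stand-in for the shared splitter contract (not the library code)."""

    def split(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Return ``(train, test)``, logging sizes like the messages in this notebook."""
        logger.info("Length of data before splitting is %d", len(data))
        train, test = self._split(data)
        logger.info(
            "Length of the train data after splitting is %d, "
            "length of the test data after splitting is %d.",
            len(train),
            len(test),
        )
        return train, test

    @abstractmethod
    def _split(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Subclasses decide which rows go to train and which go to test."""


class FirstNSplitter(SketchSplitterBase):
    """Toy subclass: the first ``n_train`` rows are train, the rest are test."""

    def __init__(self, n_train: int) -> None:
        self.n_train = n_train

    def _split(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        return data.iloc[: self.n_train], data.iloc[self.n_train :]


train_demo, test_demo = FirstNSplitter(n_train=3).split(pd.DataFrame({"x": range(5)}))
```

Because every splitter returns the same pair, code downstream of `.split` doesn't need to know which strategy produced it.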

## Setup

First, we'll read in our dataset.

In [2]:
import logging
import sys

logging.basicConfig(level=logging.INFO, stream=sys.stdout)

In [3]:
import pandas as pd

from modeling.datasets import get_sample_model_input_data



df = get_sample_model_input_data()
df.head()

INFO:numexpr.utils:Note: NumExpr detected 10 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
INFO:numexpr.utils:NumExpr defaulting to 8 threads.


|  | timestamp | air_flow01 | air_flow02 | air_flow03 | air_flow04 | air_flow05 | air_flow06 | air_flow07 | amina_flow | column_level01 | ... | ore_pulp_flow | ore_pulp_ph | silica_conc | silica_feed | starch_flow | iron_minus_silica | feed_diff_divide_silica | total_column_level | total_air_flow | silica_conc_lagged |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2017-03-09 23:00:00 | 251.166672 | 250.226086 | 250.178287 | 295.096 | 300.0 | 251.232529 | 250.208184 | 578.786678 | 450.383776 | ... | 398.753368 | 10.113487 | 9.894903 | 16.98 | 3162.625026 | 38.22 | 2.250883 | 3168.370621 | 1848.107759 |  |
| 1 | 2017-03-10 02:00:00 | 250.083563 | 250.174326 | 250.066843 | 295.096 | 300.0 | 249.992259 | 250.179793 | 574.098837 | 462.428981 | ... | 399.50087 | 10.032253 | 8.972384 | 16.98 | 3280.25859 | 38.22 | 2.250883 | 3258.210789 | 1845.592783 | 1.31 |
| 2 | 2017-03-10 05:00:00 | 250.055587 | 250.182704 | 250.051909 | 295.096 | 300.0 | 250.080709 | 250.097083 | 619.925237 | 549.723694 | ... | 399.903189 | 9.939564 | 11.834396 | 17.12 | 3199.440463 | 37.996667 | 2.219431 | 3863.737361 | 1845.563993 | 1.246667 |
| 3 | 2017-03-10 08:00:00 | 249.988883 | 250.047848 | 250.020237 | 295.096 | 300.0 | 250.159856 | 250.037881 | 590.318354 | 550.111556 | ... | 400.060293 | 10.074968 | 12.626763 | 17.4 | 3469.33155 | 37.55 | 2.158046 | 3859.372227 | 1845.350706 | 1.75 |
| 4 | 2017-03-10 11:00:00 | 250.260143 | 250.197557 | 250.0852 | 295.096 | 300.0 | 250.060176 | 250.064574 | 540.756644 | 550.344274 | ... | 400.101667 | 10.188462 | 17.352312 | 17.4 | 4297.453393 | 37.55 | 2.158046 | 3849.395119 | 1845.76365 | 2.063333 |


In [4]:
datetime_column = "timestamp"

## Splitters

### Splitter diagrams

![diagram](./_images/_SplitterBase.png)

### `ByFracSplitter`

This class ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.by_frac_splitter.ByFracSplitter)) uses `sklearn.model_selection.train_test_split` to perform the split. It also has an optional `sort_date` parameter that sorts the data by the datetime column before splitting. Any extra keyword arguments passed to `ByFracSplitter.__init__` (such as `test_size` below) are forwarded to [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html).
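Whether `ByFracSplitter` also shuffles after sorting isn't shown in this notebook. The sketch below assumes the purely chronological case, which is what `train_test_split(..., shuffle=False)` does on sorted data: the earliest rows go to train and the latest to test. `frac_split_sorted` is a hypothetical helper, not part of `modeling`. It also accounts for the 1178/295 split in the log: for a float `test_size`, scikit-learn rounds the number of test rows up.

```python
import math

import numpy as np
import pandas as pd


def frac_split_sorted(df, datetime_column, test_size):
    """Hypothetical helper: sort by time, then put the last ``test_size`` share in test.

    Mirrors ``train_test_split(..., shuffle=False)``, which rounds the test
    count up when ``test_size`` is a float.
    """
    ordered = df.sort_values(datetime_column)
    n_test = math.ceil(test_size * len(ordered))  # e.g. ceil(0.2 * 1473) = 295
    n_train = len(ordered) - n_test
    return ordered.iloc[:n_train], ordered.iloc[n_train:]


# Synthetic 3-hourly frame with the same row count as the sample data.
demo = pd.DataFrame({
    "timestamp": pd.date_range("2017-03-09 23:00", periods=1473, freq="3h"),
    "value": np.arange(1473),
})
train_demo, test_demo = frac_split_sorted(demo, "timestamp", test_size=0.2)
print(len(train_demo), len(test_demo))  # 1178 295
```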

In [5]:
from modeling import ByFracSplitter

splitter = ByFracSplitter(
    datetime_column=datetime_column,
    sort_date=True,
    test_size=0.2,  # Forwarded to train_test_split.
)
train_frac, test_frac = splitter.split(df)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 1473
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 1178, length of the test data after splitting is 295.


### `ByDateSplitter`

This class ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.by_date_splitter.ByDateSplitter)) splits the data at the provided `split_datetime`. The training data will be all samples with a timestamp earlier than `split_datetime`; the test data will be all samples at or after it. `split_datetime` can take any value accepted by `pd.to_datetime`, and `pd.to_datetime` kwargs can be passed into the `ByDateSplitter.__init__` method.
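The rule described above fits in a few lines of pandas. This is a minimal sketch, assuming the datetime column can be parsed by `pd.to_datetime`; `split_by_datetime` is a hypothetical helper, not the library's implementation. Because train uses a strict `<`, a row stamped exactly at the split point ends up in the test set.

```python
import pandas as pd


def split_by_datetime(df, datetime_column, split_datetime, **to_datetime_kwargs):
    """Hypothetical helper: rows strictly before the split go to train, the rest to test."""
    # Extra kwargs only affect how the split point itself is parsed.
    cutoff = pd.to_datetime(split_datetime, **to_datetime_kwargs)
    timestamps = pd.to_datetime(df[datetime_column])
    is_train = timestamps < cutoff
    # The row exactly at the cutoff fails `<`, so it lands in test.
    return df[is_train], df[~is_train]


demo = pd.DataFrame({
    "timestamp": pd.date_range("2017-08-30 17:00", periods=5, freq="3h"),
})
train_demo, test_demo = split_by_datetime(demo, "timestamp", "2017-08-30 23:00:00")
print(len(train_demo), len(test_demo))  # 2 3
```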

In [6]:
from modeling import ByDateSplitter

split_datetime = "2017-08-30 23:00:00"

splitter = ByDateSplitter(
    datetime_column=datetime_column,
    split_datetime=split_datetime,
)
train_data, test_data = splitter.split(df)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 1473
INFO:modeling.splitters._splitters.by_date_splitter:Splitting by datetime: 2017-08-30 23:00:00
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 1392, length of the test data after splitting is 81.


In [7]:
(
    all(train_data[datetime_column] < split_datetime),
    all(test_data[datetime_column] >= split_datetime),
)

(True, True)

In [8]:
any(test_data[datetime_column] == split_datetime)

True

### `ByIntervalsSplitter`

This class ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.by_intervals_splitter.ByIntervalsSplitter)) splits the data using explicit date ranges for each dataset. We can use it in three different ways:

1. Only provide `train_intervals`. Samples outside these ranges will be in the test set.
2. Only provide `test_intervals`. Conversely, samples outside these ranges will be in the train set.
3. Provide both `train_intervals` and `test_intervals`. Samples in neither set of ranges will not be returned.

This is useful for explicitly excluding periods of time when we know there was an operational issue that we didn't want to handle during cleaning. We'll demonstrate case 3 below.

Range bounds are inclusive, and the ranges are checked to make sure they do not overlap. As before, the lower and upper bounds can be anything `pd.to_datetime` accepts.
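A minimal sketch of the three modes and the inclusive, non-overlapping ranges described above, using `Series.between` (inclusive on both ends by default). `split_by_intervals` and its helpers are hypothetical. In particular, whether the library checks overlap only within each list or across train and test ranges isn't stated; this sketch assumes the latter.

```python
import pandas as pd


def _in_any(timestamps, intervals):
    """True where a timestamp falls inside any inclusive (start, end) interval."""
    mask = pd.Series(False, index=timestamps.index)
    for start, end in intervals:
        mask |= timestamps.between(pd.to_datetime(start), pd.to_datetime(end))
    return mask


def _check_no_overlap(intervals):
    """Raise if inclusive intervals overlap (sharing an endpoint counts as overlap)."""
    ordered = sorted((pd.to_datetime(s), pd.to_datetime(e)) for s, e in intervals)
    for (_, prev_end), (next_start, _) in zip(ordered, ordered[1:]):
        if next_start <= prev_end:
            raise ValueError(f"Overlapping intervals near {next_start}")


def split_by_intervals(df, datetime_column, train_intervals=None, test_intervals=None):
    """Hypothetical helper covering the three usage modes listed above."""
    if not train_intervals and not test_intervals:
        raise ValueError("Provide train_intervals, test_intervals, or both.")
    _check_no_overlap(list(train_intervals or []) + list(test_intervals or []))
    ts = pd.to_datetime(df[datetime_column])
    in_train = _in_any(ts, train_intervals or [])
    in_test = _in_any(ts, test_intervals or [])
    if not test_intervals:  # Case 1: everything outside train_intervals is test.
        in_test = ~in_train
    elif not train_intervals:  # Case 2: everything outside test_intervals is train.
        in_train = ~in_test
    # Case 3 (both given): rows in neither set of intervals are dropped.
    return df[in_train], df[in_test]


demo = pd.DataFrame({"timestamp": pd.date_range("2017-08-01", periods=20, freq="D")})
train_demo, test_demo = split_by_intervals(
    demo,
    "timestamp",
    train_intervals=[("2017-08-01", "2017-08-03"), ("2017-08-10", "2017-08-12")],
    test_intervals=[("2017-08-15", "2017-08-16")],
)
print(len(train_demo), len(test_demo))  # 6 2
```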

In [9]:
from modeling import ByIntervalsSplitter

train_intervals = [
    ("2017-07-03 20:00:00", "2017-08-03 20:00:00"),
    # The gap between these two intervals excludes a week in August.
    ("2017-08-10 20:00:00", "2017-08-17 20:00:00"),
]
test_intervals = [("2017-09-01 00:00:00", "2017-09-09 23:00:00")]

splitter = ByIntervalsSplitter(
    datetime_column=datetime_column,
    train_intervals=train_intervals,
    test_intervals=test_intervals,
)
train_periods, test_periods = splitter.split(df)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 1473
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 304, length of the test data after splitting is 71.


In [10]:
train_periods.shape, test_periods.shape

((304, 29), (71, 29))

In [11]:
# `x in series` checks the index, not the values, so compare the values explicitly.
check_time = pd.to_datetime("2017-08-08 20:00:00")
(
    train_periods[datetime_column].eq(check_time).any(),
    test_periods[datetime_column].eq(check_time).any(),
)

(False, False)

### `ByLastWindowSplitter`

This class ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.by_last_window.ByLastWindowSplitter)) uses the last window of time as the test set and all data before that window as the training set; for example, the last week of data for testing and everything else for training.

Below, `freq` can be any frequency string accepted by pandas' `to_offset` that can be subtracted from a `datetime` object. For example, `W`, `Y`, `M`, `D`, `H`, `min`, and `S` are all acceptable frequencies.
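A sketch of the idea, assuming the window is measured back from the latest timestamp and that rows at or after the cutoff go to test; `split_last_window` is a hypothetical helper, not the library's code. Anchored offsets from `to_offset`, such as `W` (an alias for `W-SUN`), can snap to a weekday instead of subtracting an exact number of days, so the demo uses the fixed-length `2D`.

```python
import pandas as pd
from pandas.tseries.frequencies import to_offset


def split_last_window(df, datetime_column, freq):
    """Hypothetical helper: the last ``freq`` of data is test, everything earlier is train."""
    timestamps = pd.to_datetime(df[datetime_column])
    cutoff = timestamps.max() - to_offset(freq)
    is_train = timestamps < cutoff  # Rows at or after the cutoff go to test.
    return cutoff, df[is_train], df[~is_train]


demo = pd.DataFrame({
    "timestamp": pd.date_range("2017-09-01 23:00", "2017-09-10 23:00", freq="3h"),
})
cutoff, train_demo, test_demo = split_last_window(demo, "timestamp", "2D")
print(cutoff)  # 2017-09-08 23:00:00
```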

In [12]:
from modeling import ByLastWindowSplitter

freq = "2W"  # Two weeks.

splitter = ByLastWindowSplitter(
    datetime_column=datetime_column,
    freq=freq,
)
train_last, test_last = splitter.split(df)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 1473
INFO:modeling.splitters._splitters.by_last_window:Splitting by datetime: 2017-08-27 23:00:00
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 1369, length of the test data after splitting is 104.


As the log shows, this computes the start of the last window (here `2017-08-27 23:00:00`) and then splits by that date using `split_by_date`.

In [13]:
train_last.shape, test_last.shape

((1369, 29), (104, 29))

### `BySequentialSplitter`

This class ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.by_sequential_splitter.BySequentialSplitter)) takes two frequency arguments, `block_freq` and `train_freq`. The data is divided into consecutive blocks of length `block_freq`; within each block, the first `train_freq` of time is used for training and the rest of the block for testing.
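The block logic can be sketched with timedelta arithmetic: take each row's time since the first timestamp modulo `block_freq`, and send the row to train if that offset is below `train_freq`. This assumes blocks start at the earliest timestamp, which is consistent with the `train_sequential`/`test_sequential` output in this notebook (blocks start at 23:00), and it only handles fixed-length frequencies that `pd.Timedelta` accepts. `split_sequential` is a hypothetical helper.

```python
import pandas as pd


def split_sequential(df, datetime_column, block_freq, train_freq):
    """Hypothetical helper: in each block, the first ``train_freq`` is train, the rest test."""
    timestamps = pd.to_datetime(df[datetime_column])
    block = pd.Timedelta(block_freq)
    train_len = pd.Timedelta(train_freq)
    # Time elapsed since the start of each row's block (blocks anchored at the first row).
    offset_in_block = (timestamps - timestamps.min()) % block
    is_train = offset_in_block < train_len
    return df[is_train], df[~is_train]


demo = pd.DataFrame({
    "timestamp": pd.date_range("2017-03-09 23:00", periods=16, freq="3h"),
})
train_demo, test_demo = split_sequential(demo, "timestamp", "1D", "18h")
print(test_demo.index.tolist())  # [6, 7, 14, 15]
```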

In [14]:
from modeling import BySequentialSplitter

block_freq = "1D"  # For each day...
train_freq = "18H"  # ... use 18 hours for training and the remaining 6 for testing.

splitter = BySequentialSplitter(
    datetime_column=datetime_column,
    block_freq=block_freq,
    train_freq=train_freq,
)
train_sequential, test_sequential = splitter.split(df)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 1473
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 1104, length of the test data after splitting is 369.


In [15]:
train_sequential.head(18 + 1)

|  | timestamp | air_flow01 | air_flow02 | air_flow03 | air_flow04 | air_flow05 | air_flow06 | air_flow07 | amina_flow | column_level01 | ... | ore_pulp_flow | ore_pulp_ph | silica_conc | silica_feed | starch_flow | iron_minus_silica | feed_diff_divide_silica | total_column_level | total_air_flow | silica_conc_lagged |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2017-03-09 23:00:00 | 251.166672 | 250.226086 | 250.178287 | 295.096 | 300.0 | 251.232529 | 250.208184 | 578.786678 | 450.383776 | ... | 398.753368 | 10.113487 | 9.894903 | 16.98 | 3162.625026 | 38.22 | 2.250883 | 3168.370621 | 1848.107759 |  |
| 1 | 2017-03-10 02:00:00 | 250.083563 | 250.174326 | 250.066843 | 295.096 | 300.0 | 249.992259 | 250.179793 | 574.098837 | 462.428981 | ... | 399.50087 | 10.032253 | 8.972384 | 16.98 | 3280.25859 | 38.22 | 2.250883 | 3258.210789 | 1845.592783 | 1.31 |
| 2 | 2017-03-10 05:00:00 | 250.055587 | 250.182704 | 250.051909 | 295.096 | 300.0 | 250.080709 | 250.097083 | 619.925237 | 549.723694 | ... | 399.903189 | 9.939564 | 11.834396 | 17.12 | 3199.440463 | 37.996667 | 2.219431 | 3863.737361 | 1845.563993 | 1.246667 |
| 3 | 2017-03-10 08:00:00 | 249.988883 | 250.047848 | 250.020237 | 295.096 | 300.0 | 250.159856 | 250.037881 | 590.318354 | 550.111556 | ... | 400.060293 | 10.074968 | 12.626763 | 17.4 | 3469.33155 | 37.55 | 2.158046 | 3859.372227 | 1845.350706 | 1.75 |
| 4 | 2017-03-10 11:00:00 | 250.260143 | 250.197557 | 250.0852 | 295.096 | 300.0 | 250.060176 | 250.064574 | 540.756644 | 550.344274 | ... | 400.101667 | 10.188462 | 17.352312 | 17.4 | 4297.453393 | 37.55 | 2.158046 | 3849.395119 | 1845.76365 | 2.063333 |
| 5 | 2017-03-10 14:00:00 | 250.119752 | 249.85577 | 250.041583 | 295.096 | 300.0 | 250.053024 | 250.022956 | 523.987763 | 549.869907 | ... | 399.303424 | 10.343236 | 20.075338 | 17.32 | 4990.956981 | 37.976667 | 2.192648 | 3859.242841 | 1845.189085 | 2.386667 |
| 8 | 2017-03-10 23:00:00 | 250.133769 | 250.297398 | 250.060098 | 295.096 | 300.0 | 250.047656 | 250.003893 | 346.474348 | 549.918187 | ... | 400.273437 | 10.061604 | 14.23933 | 14.19 | 4618.658196 | 43.356667 | 3.055438 | 3705.504719 | 1845.638813 | 1.373333 |
| 9 | 2017-03-11 02:00:00 | 249.849022 | 250.050174 | 249.925398 | 295.096 | 300.0 | 250.109602 | 249.977889 | 336.002676 | 550.197283 | ... | 399.79963 | 10.170165 | 12.922474 | 8.25 | 3120.682148 | 52.41 | 6.352727 | 3788.002292 | 1845.008085 | 1.183333 |
| 10 | 2017-03-11 05:00:00 | 250.130278 | 250.172069 | 250.20743 | 295.096 | 300.0 | 250.093452 | 250.044387 | 343.316696 | 553.37225 | ... | 400.214996 | 9.821706 | 13.494741 | 8.493333 | 2365.553513 | 51.91 | 6.111852 | 3834.263798 | 1845.743615 | 3.326667 |
| 11 | 2017-03-11 08:00:00 | 250.182476 | 249.908991 | 250.050211 | 295.096 | 300.0 | 250.125172 | 249.950078 | 523.351189 | 564.57508 | ... | 400.152228 | 9.617535 | 14.569071 | 8.98 | 2207.344234 | 50.91 | 5.669265 | 3752.557437 | 1845.312928 | 4.68 |


In [16]:
test_sequential.head(6 + 1)

|  | timestamp | air_flow01 | air_flow02 | air_flow03 | air_flow04 | air_flow05 | air_flow06 | air_flow07 | amina_flow | column_level01 | ... | ore_pulp_flow | ore_pulp_ph | silica_conc | silica_feed | starch_flow | iron_minus_silica | feed_diff_divide_silica | total_column_level | total_air_flow | silica_conc_lagged |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 6 | 2017-03-10 17:00:00 | 249.963543 | 250.093681 | 250.070113 | 295.096 | 300.0 | 250.009341 | 250.027998 | 517.197504 | 549.93362 | ... | 399.737507 | 10.344621 | 19.240095 | 17.16 | 4976.461926 | 38.83 | 2.262821 | 3861.450931 | 1845.260676 | 1.513333 |
| 7 | 2017-03-10 20:00:00 | 249.94378 | 250.169919 | 250.039863 | 295.096 | 300.0 | 249.995511 | 250.082337 | 435.577665 | 515.753181 | ... | 399.750733 | 10.066053 | 11.701512 | 17.16 | 4017.272087 | 38.83 | 2.262821 | 3469.260628 | 1845.327409 | 2.216667 |
| 14 | 2017-03-11 17:00:00 | 249.810896 | 250.223563 | 250.039122 | 295.096 | 300.0 | 249.896685 | 250.09005 | 552.6112 | 451.154317 | ... | 398.985193 | 9.974552 | 12.265194 | 9.09 | 3381.36661 | 50.57 | 5.563256 | 3172.039398 | 1845.156317 | 2.443333 |
| 15 | 2017-03-11 20:00:00 | 249.944893 | 250.123711 | 250.047381 | 295.096 | 300.0 | 250.008059 | 250.064146 | 558.763622 | 512.218976 | ... | 399.915413 | 9.878451 | 13.994933 | 9.09 | 3127.612054 | 50.57 | 5.563256 | 3506.238645 | 1845.284191 | 2.36 |
| 22 | 2017-03-12 17:00:00 | 250.213143 | 250.000272 | 250.053331 | 295.096 | 300.0 | 250.221459 | 250.083674 | 533.323226 | 501.240696 | ... | 399.819889 | 9.878118 | 13.061891 | 11.2 | 3891.145269 | 47.85 | 4.272321 | 3451.072502 | 1845.66788 | 2.683333 |
| 23 | 2017-03-12 20:00:00 | 250.097822 | 250.090957 | 250.052933 | 295.096 | 300.0 | 250.187331 | 250.026326 | 446.119306 | 623.963726 | ... | 401.03273 | 9.95817 | 14.56088 | 11.2 | 3942.05288 | 47.85 | 4.272321 | 4125.149356 | 1845.55137 | 1.69 |
| 30 | 2017-03-13 17:00:00 | 250.074976 | 250.09108 | 250.136496 | 295.096 | 300.0 | 250.02667 | 250.029707 | 423.50758 | 587.034526 | ... | 399.823439 | 9.987743 | 12.833483 | 8.94 | 3872.897195 | 50.01 | 5.59396 | 3948.833264 | 1845.45493 | 1.676667 |


### `ByColumnValueSplitter`

This class ([API](../../../../../../docs/build/apidoc/modeling/modeling.splitters.html#modeling.splitters.by_column_value.ByColumnValueSplitter)) splits data based on the values of a single column: rows whose value belongs to a specified collection go to the test set, and the rest go to the train set.

This splitter is useful when **working with panel data**, e.g.:

- In oil well data, send a predefined set of wells to the test dataset.
- In retail store data, send a predefined set of stores to the test dataset.

Its constructor takes two arguments:
* `column_name`: the column to split by.
* `values_for_test`: which values of this column to send to the test set.
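The rule boils down to `Series.isin`. The sketch below is a hypothetical helper, not the library's code; it uses a toy oil-well panel to show the property that matters for panel data: no group ends up in both train and test.

```python
import pandas as pd


def split_by_column_value(df, column_name, values_for_test):
    """Hypothetical helper: rows whose value is in ``values_for_test`` go to test."""
    is_test = df[column_name].isin(values_for_test)
    return df[~is_test], df[is_test]


panel = pd.DataFrame({
    "well": ["A", "B", "C", "A", "B", "C"],
    "rate": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
})
train_demo, test_demo = split_by_column_value(panel, "well", ["C"])
print(sorted(test_demo["well"].unique()))  # ['C']
```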

Let's create a dummy dataset to showcase such a splitter.
Imagine we have temperature and humidity observations from 5 countries, and the goal is to predict `temperature` from `humidity`:

In [17]:
import numpy as np

rng = np.random.default_rng(42)
n_rows = 30

data = pd.DataFrame({
    "country": rng.choice(("USA", "Canada", "Germany", "China", "Brazil"), n_rows),
    "temperature": rng.random(n_rows),
    "humidity": rng.random(n_rows),
})

data

|  | country | temperature | humidity |
|---|---|---|---|
| 0 | USA | 0.227239 | 0.804764 |
| 1 | China | 0.554585 | 0.387478 |
| 2 | China | 0.063817 | 0.288328 |
| 3 | Germany | 0.827631 | 0.682496 |
| 4 | Germany | 0.631664 | 0.139752 |
| 5 | Brazil | 0.758088 | 0.199908 |
| 6 | USA | 0.354526 | 0.007362 |
| 7 | China | 0.970698 | 0.786924 |
| 8 | Canada | 0.893121 | 0.664851 |
| 9 | USA | 0.778383 | 0.705165 |


Now let's assume that we want to test the model on `USA` and `Germany` data and use the other countries' data for training. Here is how we can achieve this:

In [18]:
from modeling import ByColumnValueSplitter

splitter = ByColumnValueSplitter(
    column_name="country",
    values_for_test=["USA", "Germany",],
)
train_by_column_value, test_by_column_value = splitter.split(data)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 30
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 16, length of the test data after splitting is 14.


Let's see what the train data looks like:

In [19]:
train_by_column_value

|  | country | temperature | humidity |
|---|---|---|---|
| 1 | China | 0.554585 | 0.387478 |
| 2 | China | 0.063817 | 0.288328 |
| 5 | Brazil | 0.758088 | 0.199908 |
| 7 | China | 0.970698 | 0.786924 |
| 8 | Canada | 0.893121 | 0.664851 |
| 11 | Brazil | 0.466721 | 0.458916 |
| 12 | China | 0.043804 | 0.568741 |
| 13 | China | 0.154289 | 0.139797 |
| 14 | China | 0.683049 | 0.11453 |
| 15 | China | 0.744762 | 0.668403 |


And the test data:

In [20]:
test_by_column_value

|  | country | temperature | humidity |
|---|---|---|---|
| 0 | USA | 0.227239 | 0.804764 |
| 3 | Germany | 0.827631 | 0.682496 |
| 4 | Germany | 0.631664 | 0.139752 |
| 6 | USA | 0.354526 | 0.007362 |
| 9 | USA | 0.778383 | 0.705165 |
| 10 | Germany | 0.194639 | 0.780729 |
| 16 | Germany | 0.96751 | 0.471096 |
| 17 | USA | 0.325825 | 0.565236 |
| 19 | Germany | 0.469556 | 0.634718 |
| 20 | Germany | 0.189471 | 0.553579 |


## Functional

The `functional` subpackage lets you work with the classes above in a functional style. This can be especially useful with pipelines or other orchestration tools that require simple callable objects (e.g. Kedro).
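One common way to build such a functional layer is a registry that maps the `split_method` string to a class and instantiates it from a plain dict of parameters, which is convenient when the parameters come from a config file. Everything below (`SPLITTER_REGISTRY`, `create_splitter_sketch`, `split_data_sketch`, `DateSplitterSketch`) is hypothetical and only illustrates the pattern; it is not necessarily how `modeling` implements `create_splitter`.

```python
from typing import Any, Callable, Dict, Tuple

import pandas as pd


class DateSplitterSketch:
    """Hypothetical stand-in for a splitter class; not ``modeling.ByDateSplitter``."""

    def __init__(self, datetime_column: str, split_datetime: str) -> None:
        self.datetime_column = datetime_column
        self.split_datetime = pd.to_datetime(split_datetime)

    def split(self, data: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
        is_train = pd.to_datetime(data[self.datetime_column]) < self.split_datetime
        return data[is_train], data[~is_train]


# Hypothetical registry: string keys -> splitter classes.
SPLITTER_REGISTRY: Dict[str, Callable[..., Any]] = {"date": DateSplitterSketch}


def create_splitter_sketch(split_method: str, splitting_parameters: Dict[str, Any]):
    """Look up the class for ``split_method`` and build it from a plain dict."""
    try:
        splitter_cls = SPLITTER_REGISTRY[split_method]
    except KeyError:
        raise ValueError(f"Unknown split_method: {split_method!r}") from None
    return splitter_cls(**splitting_parameters)


def split_data_sketch(splitter, data: pd.DataFrame):
    """Plain function wrapper, convenient as a pipeline node."""
    return splitter.split(data)


demo_splitter = create_splitter_sketch(
    "date", {"datetime_column": "timestamp", "split_datetime": "2017-01-03"}
)
demo = pd.DataFrame({"timestamp": pd.date_range("2017-01-01", periods=5, freq="D")})
train_demo, test_demo = split_data_sketch(splitter=demo_splitter, data=demo)
```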

In [21]:
from modeling import create_splitter, split_data

In [22]:
splitter = create_splitter(
    split_method="date",
    splitting_parameters={
        "datetime_column": datetime_column,
        "split_datetime": split_datetime,
    },
)
splitter

ByDateSplitter(datetime_column='timestamp', split_datetime=Timestamp('2017-08-30 23:00:00'), )

In [23]:
help(create_splitter)

Help on function create_splitter in module modeling.splitters._splitters.functional:

create_splitter(split_method: Literal['date', 'frac', 'intervals', 'last_window', 'sequential_window', 'column_value'], splitting_parameters: Dict[str, Any]) -> modeling.splitters._splitters.base_splitter.SplitterBase
    Create ``SplitterBase`` instance from split_method and splitting parameters.
    
    Supported str options for ``split_method``:
        * "date" to initialize ``ByDateSplitter``
        * "frac" to initialize ``ByFracSplitter``
        * "intervals" to initialize ``ByIntervalsSplitter``
        * "last_window" to initialize ``ByLastWindowSplitter``
        * "sequential_window" to initialize ``BySequentialSplitter``
        * "column_value" to initialize ``ByColumnValueSplitter``
    
    Args:
        split_method: method for choosing type of inheritor of ModelBase to initialize
        splitting_parameters: parameters used for splitter initialization.
    
    Notes:
        ``sp

In [24]:
train_data, test_data = split_data(splitter=splitter, data=df)

INFO:modeling.splitters._splitters.base_splitter:Length of data before splitting is 1473
INFO:modeling.splitters._splitters.by_date_splitter:Splitting by datetime: 2017-08-30 23:00:00
INFO:modeling.splitters._splitters.base_splitter:Length of the train data after splitting is 1392, length of the test data after splitting is 81.
