### Train/test splitters operating on (multiple) xarray DataArrays

In [1]:
# Create dummy data
import numpy as np
import pandas as pd
import xarray as xr

# Hide the full data when displaying a dataset in the notebook
xr.set_options(display_expand_data=False)

# Fixed seed so that "Restart & Run All" reproduces the same random values
SEED = 42
rng = np.random.default_rng(SEED)

n = 50  # number of time steps, spaced 60 days apart
time_index = pd.date_range("20151020", periods=n, freq="60d")
time_coord = {"time": time_index}
# Pass `dims` explicitly rather than relying on inference from the coords dict
x1 = xr.DataArray(rng.standard_normal(n), coords=time_coord, dims="time", name="precursor1")
x2 = xr.DataArray(rng.standard_normal(n), coords=time_coord, dims="time", name="precursor2")
y = xr.DataArray(rng.standard_normal(n), coords=time_coord, dims="time", name="target")
print(x1)

<xarray.DataArray 'precursor1' (time: 50)>
-0.08806 -1.723 -1.14 -0.0535 0.6887 ... 0.6837 -0.523 -1.067 0.5144 0.1049
Coordinates:
  * time     (time) datetime64[ns] 2015-10-20 2015-12-19 ... 2023-11-07


In [2]:
# Fit to calendar: 180-day intervals anchored on 15 October (month=10, day=15)
import s2spy.time

calendar = s2spy.time.AdventCalendar(anchor=(10, 15), freq="180d")

# The calendar is mapped using x1 only; x2 and y share the same time coordinate.
# TODO: would be nice to pass in multiple at once.
calendar.map_to_data(x1)

# Resample every series onto the calendar (same call order: x1, x2, y).
# NOTE: this replaces the raw DataArrays with resampled Datasets, so re-run the
# data-creation cell before running this cell again.
x1, x2, y = [s2spy.time.resample(calendar, data) for data in (x1, x2, y)]

print(x1)

<xarray.Dataset>
Dimensions:      (anchor_year: 8, i_interval: 2)
Coordinates:
  * anchor_year  (anchor_year) int64 2016 2017 2018 2019 2020 2021 2022 2023
  * i_interval   (i_interval) int64 0 1
    index        (anchor_year, i_interval) int64 0 1 2 3 4 5 ... 11 12 13 14 15
    interval     (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....
    target       (i_interval) bool True False
Data variables:
    precursor1   (anchor_year, i_interval) float64 0.9591 -0.9723 ... 0.5215


In [3]:
# Cross-validation: wrap a scikit-learn splitter in s2spy's TrainTestSplit
from sklearn.model_selection import KFold
import s2spy.traintest

kfold = KFold(n_splits=3)
cv = s2spy.traintest.TrainTestSplit(kfold)

# Each fold yields the predictors' train subsets (in argument order), their
# test subsets, and then the train/test subsets of y.
for (
    (x1_train, x2_train),
    (x1_test, x2_test),
    y_train,
    y_test,
) in cv.split(x1, x2, y=y):
    print("Train:", x1_train.anchor_year.values)
    print("Test:", x1_test.anchor_year.values)

# Loop variables persist after the loop, so this shows the training split of
# the last fold only.
print(x1_train)

Train: [2019 2020 2021 2022 2023]
Test: [2016 2017 2018]
Train: [2016 2017 2018 2022 2023]
Test: [2019 2020 2021]
Train: [2016 2017 2018 2019 2020 2021]
Test: [2022 2023]
<xarray.Dataset>
Dimensions:      (anchor_year: 6, i_interval: 2)
Coordinates:
  * anchor_year  (anchor_year) int64 2016 2017 2018 2019 2020 2021
  * i_interval   (i_interval) int64 0 1
    index        (anchor_year, i_interval) int64 0 1 2 3 4 5 6 7 8 9 10 11
    interval     (anchor_year, i_interval) object (2016-04-18, 2016-10-15] .....
    target       (i_interval) bool True False
Data variables:
    precursor1   (anchor_year, i_interval) float64 0.9591 -0.9723 ... 0.7427


In [4]:
# Alternative using shorthand notation: collect the predictors in a list and
# unpack it into `split`; each fold then returns all train subsets together and
# all test subsets together.
x = [x1, x2]
for x_train, x_test, y_train, y_test in cv.split(*x, y=y):
    # Per-predictor subsets come back in the same order as `x`
    (x1_train, x2_train), (x1_test, x2_test) = x_train, x_test
    print("Train:", x1_train.anchor_year.values)
    print("Test:", x1_test.anchor_year.values)

Train: [2019 2020 2021 2022 2023]
Test: [2016 2017 2018]
Train: [2016 2017 2018 2022 2023]
Test: [2019 2020 2021]
Train: [2016 2017 2018 2019 2020 2021]
Test: [2022 2023]
