In [None]:
import yaml
import xarray as xr
import os
import pickle
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
import geopandas as gpd

from h2ox.ai.dataset.dataset_factory import DatasetFactory
from h2ox.ai.dataset.dataset import FcastDataset
from h2ox.ai.dataset.utils import group_consecutive_nans

%load_ext autoreload
%autoreload 2

# Test DatasetFactory

In [None]:
# Load the experiment configuration (path is relative to the notebook's
# working directory). The context manager guarantees the file handle is
# closed as soon as parsing finishes; SafeLoader restricts parsing to plain
# YAML types.
with open('./../conf.yaml', 'r') as conf_file:
    cfg = yaml.load(conf_file, Loader=yaml.SafeLoader)

In [None]:
# Instantiate the factory from the parsed YAML configuration.
dataset_factory = DatasetFactory(cfg)

In [None]:
# Test the PyTorch dataset build: run the factory's `build_dataset()` entry
# point and keep the result as `ptdf` for inspection.
ptdf = dataset_factory.build_dataset()

# Walkthrough: building the dataset

In [None]:
# Test the data build: call the factory's internal `_build_data()` step directly.
# The helper functions defined further down treat the result as an xr.Dataset.
data = dataset_factory._build_data()

In [None]:
# Window lengths, interpreted as day counts by the helper functions below.
forecast_horizon = 14    # forecast inputs cover steps 1 .. forecast_horizon
future_horizon = 76      # future inputs cover steps forecast_horizon+1 .. future_horizon
historical_seq_len = 60  # number of lagged copies stacked for the historic inputs

# Variable names selected from the built dataset for each model input / output.
target_var = ["targets_WATER_VOLUME"]
historic_variables = [
    "historic_t2m",
    "historic_tp",
    "targets_WATER_VOLUME",
    "doy_cos",
]
forecast_variables = [
    "forecast_tp",
    "forecast_t2m",
    "doy_cos",
]
future_variables = ["doy_cos"]

In [None]:
def _get_historic_data(
    data: xr.Dataset,
) -> xr.DataArray:
    """Stack lagged copies of the historic variables along a new ``historic_roll`` axis.

    Takes the ``steps == 0`` slice of ``historic_variables`` and shifts it
    along ``date`` by ``0 .. historical_seq_len - 1`` positions, so every date
    also carries the values of the preceding positions. ``historic_variables``
    and ``historical_seq_len`` are read from the notebook's global scope.

    Args:
        data: Dataset containing ``historic_variables``; it must expose
            ``date``, ``steps`` and ``global_sites`` (used by the selection,
            the shift and the final transpose).

    Returns:
        DataArray with dims ``(date, global_sites, variable, historic_roll)``.
        Lag ``ii`` is labelled ``historical_seq_len - ii`` days, i.e. the
        unshifted slice carries the largest label and the oldest lag is
        labelled 1 day.
    """
    # The steps == 0 slice does not depend on the lag, so select it once.
    at_step_zero = data[historic_variables].sel({'steps': np.timedelta64(0)})

    # A positive shift moves values toward later dates: row `date` receives the
    # value from `lag` positions earlier (= `lag` days if `date` is a gap-free
    # daily index -- TODO confirm). Positions shifted in from before the start
    # of the record are filled with missing values.
    lagged = [at_step_zero.shift({'date': lag}) for lag in range(historical_seq_len)]

    # NOTE(review): labels run from historical_seq_len days (lag 0) down to
    # 1 day (oldest lag), so the axis is stored newest-first with descending
    # labels -- confirm this is the ordering the model expects.
    roll_labels = pd.TimedeltaIndex(
        [timedelta(days=historical_seq_len - lag) for lag in range(historical_seq_len)],
        name="historic_roll",
    )

    data_h = xr.concat(lagged, roll_labels)

    return data_h.to_array().transpose('date', 'global_sites', 'variable', 'historic_roll')


def _get_forecast_data(
    data: xr.Dataset,
) -> xr.DataArray:
    """Select the forecast variables at lead times of 1 .. ``forecast_horizon`` days.

    ``forecast_variables`` and ``forecast_horizon`` are read from the
    notebook's global scope.

    Args:
        data: Dataset containing ``forecast_variables``; it must expose
            ``date``, ``steps`` and ``global_sites``, and its ``steps``
            coordinate must contain every requested lead time.

    Returns:
        DataArray with dims ``(date, global_sites, variable, steps)``, where
        ``steps`` holds the ``forecast_horizon`` selected lead times.
    """
    # Lead times start at 1 day: step 0 is excluded from the forecast window.
    lead_times = pd.TimedeltaIndex(
        [timedelta(days=day) for day in range(1, forecast_horizon + 1)]
    )

    forecast = data[forecast_variables].to_array().sel({'steps': lead_times})

    return forecast.transpose('date', 'global_sites', 'variable', 'steps')


def _get_future_data(
    data: xr.Dataset,
) -> xr.DataArray:
    """Select the future variables at lead times beyond the forecast window.

    The selected steps run from ``forecast_horizon + 1`` to ``future_horizon``
    days inclusive, so ``future_horizon`` is counted from step 0, not from the
    end of the forecast window. ``future_variables``, ``forecast_horizon`` and
    ``future_horizon`` are read from the notebook's global scope.

    Args:
        data: Dataset containing ``future_variables``; it must expose
            ``date``, ``steps`` and ``global_sites``, and its ``steps``
            coordinate must contain every requested lead time.

    Returns:
        DataArray with dims ``(date, global_sites, variable, steps)``.
    """
    # Continue directly after the last forecast step.
    lead_times = pd.TimedeltaIndex(
        [timedelta(days=day) for day in range(forecast_horizon + 1, future_horizon + 1)]
    )

    future = data[future_variables].to_array().sel({'steps': lead_times})

    return future.transpose('date', 'global_sites', 'variable', 'steps')


def _get_target_data(
    data: xr.Dataset,
) -> xr.DataArray:
    """Stack future values of the target variable along a new ``target_roll`` axis.

    Takes the ``steps == 0`` slice of ``target_var`` and, for every offset
    ``ii`` in ``0 .. forecast_horizon + future_horizon``, aligns the value
    found ``ii`` positions *later* with the current ``date``. ``target_var``,
    ``forecast_horizon`` and ``future_horizon`` are read from the notebook's
    global scope.

    Args:
        data: Dataset containing ``target_var``; it must expose ``date``,
            ``steps`` and ``global_sites`` (used by the selection, the shift
            and the final transpose).

    Returns:
        DataArray with dims ``(date, global_sites, variable, target_roll)``;
        ``target_roll`` is labelled with the offset in days (0 = same date).
        Offsets that reach past the end of the record are filled with missing
        values.
    """
    # NOTE(review): `_get_future_data` already counts future_horizon from
    # step 0 (its steps stop at future_horizon days), so
    # forecast_horizon + future_horizon + 1 offsets may be more than
    # intended -- confirm before relying on the tail of this axis.
    n_offsets = forecast_horizon + future_horizon + 1

    # The steps == 0 slice does not depend on the offset, so select it once.
    at_step_zero = data[target_var].sel({'steps': np.timedelta64(0)})

    # A NEGATIVE shift pulls values from later dates back onto `date`, which is
    # what the positive `target_roll` labels describe. A positive shift would
    # return past values instead -- the very values `_get_historic_data`
    # supplies as model inputs -- and leak the targets into the features.
    # (Offsets equal days only if `date` is a gap-free daily index -- TODO confirm.)
    shifted = [at_step_zero.shift({'date': -offset}) for offset in range(n_offsets)]

    roll_labels = pd.TimedeltaIndex(
        [timedelta(days=offset) for offset in range(n_offsets)],
        name="target_roll",
    )

    data_y = xr.concat(shifted, roll_labels)

    return data_y.to_array().transpose('date', 'global_sites', 'variable', 'target_roll')

In [None]:
def _onehotencode(data_portion, offset_dim):
    """Add one-hot site indicator variables and flatten date/site into one axis.

    Args:
        data_portion: DataArray with dims ``date``, ``global_sites``,
            ``variable`` and ``offset_dim``.
        offset_dim: Name of the remaining (lag / lead-time) dimension of
            ``data_portion``, e.g. ``'steps'``.

    Returns:
        Dataset holding one data variable per entry of ``variable`` plus the
        dummy columns produced by ``pd.get_dummies`` for ``global_sites``,
        with ``date`` and ``global_sites`` stacked into ``date-site``.
    """
    sample_dims = ('date', 'global_sites')

    # One row per (date, site) sample carrying that sample's site label, in a
    # form pandas can dummy-encode.
    stacked = data_portion.transpose('date', 'global_sites', offset_dim, 'variable').stack(
        {"date-site": sample_dims}
    )
    site_labels = stacked['global_sites'].to_dataframe()

    # Encode the labels and bring the result back into xarray.
    ohe = pd.get_dummies(site_labels).to_xarray()

    # Split `variable` into separate data variables so the indicator columns
    # can sit alongside them, then flatten the sample axes.
    merged = xr.merge([data_portion.to_dataset(dim='variable'), ohe])

    return merged.stack({"date-site": sample_dims})

In [None]:
# Build the four arrays from `data`.
# The historic and target helpers select a single `steps` value, which leaves
# `steps` behind as a scalar coordinate; remove it with `drop_vars` (the
# non-deprecated spelling of `drop` for coordinates/variables). Presumably
# this keeps it from clashing with the real `steps` dimension of
# forecast/future -- confirm.
historic = _get_historic_data(data).drop_vars('steps')
forecast = _get_forecast_data(data)
future = _get_future_data(data)
target = _get_target_data(data).drop_vars('steps')

In [None]:
# Append the site indicator variables to each array, passing the name of its
# lag / lead-time dimension.
# NOTE: this cell rebinds its own inputs, so it is not idempotent -- running it
# twice feeds the already-encoded output back into `_onehotencode`. Re-run the
# previous cell first if this one needs to be repeated.
historic = _onehotencode(historic, 'historic_roll')
forecast = _onehotencode(forecast, 'steps')
future = _onehotencode(future, 'steps')
target = _onehotencode(target, 'target_roll')