In [2]:
import pickle
import pandas as pd
import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import QuantileTransformer
from statsmodels.tsa.holtwinters import ExponentialSmoothing
import matplotlib.pyplot as plt
from typing import Any, Dict, List

In [3]:
df_spatiotemporal = capacity_factors_daily_2000to2015 = pd.read_hdf(
    path_or_buf='../data/05_model_input/df_spatiotemporal.hdf', 
    key='df_spatiotemporal'
)

with open('../data/05_model_input/cv_splits_dict.pkl/2020-08-28T04.48.19.250Z/cv_splits_dict.pkl', 'rb') as pkl_file:
    cv_splits_dict = pickle.load(pkl_file)

## Non-Essential

In [4]:
df_spatiotemporal['temporal'].columns.get_level_values('district')

Index(['DE111', 'DE114', 'DE115', 'DE116', 'DE118', 'DE119', 'DE11A', 'DE11B',
       'DE11C', 'DE11D',
       ...
       'DEG0E', 'DEG0F', 'DEG0G', 'DEG0I', 'DEG0J', 'DEG0K', 'DEG0L', 'DEG0M',
       'DEG0N', 'DEG0P'],
      dtype='object', name='district', length=292)

In [5]:
df_spatiotemporal['temporal']['DEF0C']

var,power
2013-01-01,0.298824
2013-01-02,0.277085
2013-01-03,0.561925
2013-01-04,0.603515
2013-01-05,0.139887
...,...
2015-12-27,0.652169
2015-12-28,0.313875
2015-12-29,0.591666
2015-12-30,0.460236


In [6]:
df_spatiotemporal.index

DatetimeIndex(['2013-01-01', '2013-01-02', '2013-01-03', '2013-01-04',
               '2013-01-05', '2013-01-06', '2013-01-07', '2013-01-08',
               '2013-01-09', '2013-01-10',
               ...
               '2015-12-22', '2015-12-23', '2015-12-24', '2015-12-25',
               '2015-12-26', '2015-12-27', '2015-12-28', '2015-12-29',
               '2015-12-30', '2015-12-31'],
              dtype='datetime64[ns]', length=1095, freq='D')

In [7]:
df_spatiotemporal.loc['2013-01-01': '2015-12-22']

data_type,spatial,spatial,spatial,spatial,spatial,spatial,spatial,spatial,spatial,spatial,...,temporal,temporal,temporal,temporal,temporal,temporal,temporal,temporal,temporal,temporal
district,DE111,DE111,DE114,DE114,DE115,DE115,DE116,DE116,DE118,DE118,...,DEG0E,DEG0F,DEG0G,DEG0I,DEG0J,DEG0K,DEG0L,DEG0M,DEG0N,DEG0P
var,lat,lon,lat,lon,lat,lon,lat,lon,lat,lon,...,power,power,power,power,power,power,power,power,power,power
2013-01-01,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.187667,0.363388,0.393872,0.482197,0.551744,0.435241,0.372426,0.464629,0.393286,0.497463
2013-01-02,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.073910,0.133930,0.226143,0.205528,0.284261,0.219417,0.215123,0.272299,0.203758,0.248020
2013-01-03,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.107586,0.297686,0.451200,0.379433,0.536884,0.463397,0.424858,0.528489,0.426369,0.502846
2013-01-04,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.124282,0.342673,0.427804,0.447503,0.505166,0.311061,0.373373,0.438596,0.329357,0.347780
2013-01-05,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.294815,0.357460,0.321136,0.415867,0.361462,0.349354,0.301562,0.275772,0.331298,0.337855
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2015-12-18,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.271255,9.420741,...,0.050321,0.099918,0.146184,0.155744,0.277486,0.190390,0.248554,0.277367,0.191652,0.235058
2015-12-19,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.271255,9.420741,...,0.010416,0.044059,0.090099,0.092734,0.209859,0.167729,0.238844,0.305741,0.125329,0.143150
2015-12-20,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.271255,9.420741,...,0.053707,0.169616,0.117497,0.264875,0.397657,0.484214,0.514653,0.618273,0.185191,0.228600
2015-12-21,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.273656,9.418495,...,0.241068,0.381588,0.406004,0.470451,0.558143,0.548531,0.532494,0.553334,0.487775,0.547471


In [8]:
train = slice('2013-01-01', '2015-12-22')
df_spatiotemporal[train]

data_type,spatial,spatial,spatial,spatial,spatial,spatial,spatial,spatial,spatial,spatial,...,temporal,temporal,temporal,temporal,temporal,temporal,temporal,temporal,temporal,temporal
district,DE111,DE111,DE114,DE114,DE115,DE115,DE116,DE116,DE118,DE118,...,DEG0E,DEG0F,DEG0G,DEG0I,DEG0J,DEG0K,DEG0L,DEG0M,DEG0N,DEG0P
var,lat,lon,lat,lon,lat,lon,lat,lon,lat,lon,...,power,power,power,power,power,power,power,power,power,power
2013-01-01,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.187667,0.363388,0.393872,0.482197,0.551744,0.435241,0.372426,0.464629,0.393286,0.497463
2013-01-02,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.073910,0.133930,0.226143,0.205528,0.284261,0.219417,0.215123,0.272299,0.203758,0.248020
2013-01-03,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.107586,0.297686,0.451200,0.379433,0.536884,0.463397,0.424858,0.528489,0.426369,0.502846
2013-01-04,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.124282,0.342673,0.427804,0.447503,0.505166,0.311061,0.373373,0.438596,0.329357,0.347780
2013-01-05,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.359772,9.441498,...,0.294815,0.357460,0.321136,0.415867,0.361462,0.349354,0.301562,0.275772,0.331298,0.337855
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2015-12-18,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.271255,9.420741,...,0.050321,0.099918,0.146184,0.155744,0.277486,0.190390,0.248554,0.277367,0.191652,0.235058
2015-12-19,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.271255,9.420741,...,0.010416,0.044059,0.090099,0.092734,0.209859,0.167729,0.238844,0.305741,0.125329,0.143150
2015-12-20,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.271255,9.420741,...,0.053707,0.169616,0.117497,0.264875,0.397657,0.484214,0.514653,0.618273,0.185191,0.228600
2015-12-21,48.831018,9.097432,48.625037,9.807222,48.974292,9.171993,48.893531,9.658886,49.273656,9.418495,...,0.241068,0.381588,0.406004,0.470451,0.558143,0.548531,0.532494,0.553334,0.487775,0.547471


## Core

In [9]:
def split_data(df: pd.DataFrame, modeling):
    train = slice(
        modeling['train_window']['start'],
        modeling['train_window']['end']
    )
    
    test = slice(
        modeling['test_window']['start'],
        modeling['test_window']['end']
    )
    
    return {
        'df_train': df[train],
        'df_test': df[test]
    }

In [51]:
class MakeStrictlyPositive(TransformerMixin, BaseEstimator):
    '''Add constant to variable so that it only assumes positive values.'''

    def __init__(self):
        pass

    def fit(self, X, y=None):
        self.offset_ = X.min(axis=0)
        return self 
    
    def transform(self, X, y=None):
        return X + abs(self.offset_) + 1e-08
    
    def inverse_transform(self, X, y=None):
        return X - abs(self.offset_) - 1e-08

In [41]:
modeling = {
    'train_window': {
        'start': '2013-01-01',
        'end':'2015-06-22',
    },
    'test_window': {
        'start': '2015-06-23',
        'end':'2015-06-29',
    },
    'preprocessing': [
        'get_quantile_equivalent_normal_dist',
        'make_strictly_positive',
    ],
    'inference': {
          'approach': 'HW-ES',
          'mode': 'districtwise',
          'trend': 'additive',
          'damped_trend': True,
          'seasonal': 'multiplicative',
          'seasonal_periods': 365,
    },
    'target_timeseries': ['DEF0C', 'DE111']
}


REGISTERED_TRANSFORMERS = {
    'get_quantile_equivalent_normal_dist': QuantileTransformer(
                                                output_distribution='normal', 
                                                random_state=0,
                                            ),
    'make_strictly_positive': MakeStrictlyPositive(),
}

In [42]:
preprocessing_pipeline = make_pipeline(
    QuantileTransformer(
        output_distribution='normal', 
        random_state=0,
    ),
    MakeStrictlyPositive(),
)

In [43]:
preprocessing_pipeline = make_pipeline(
    *[ TRANSFORMERS[ step ] for step in modeling['preprocessing'] ]
)

In [44]:
train_test_split = split_data(df_spatiotemporal, modeling)

In [72]:
df_train = train_test_split['df_train']

df_train['temporal'].head()

district,DE111,DE114,DE115,DE116,DE118,DE119,DE11A,DE11B,DE11C,DE11D,...,DEG0E,DEG0F,DEG0G,DEG0I,DEG0J,DEG0K,DEG0L,DEG0M,DEG0N,DEG0P
var,power,power,power,power,power,power,power,power,power,power,...,power,power,power,power,power,power,power,power,power,power
2013-01-01,0.178783,0.269458,0.35113,0.184231,0.357407,0.410291,0.362935,0.344386,0.28414,0.360815,...,0.187667,0.363388,0.393872,0.482197,0.551744,0.435241,0.372426,0.464629,0.393286,0.497463
2013-01-02,0.030363,0.063571,0.103089,0.030433,0.07874,0.108224,0.105199,0.093845,0.083709,0.093448,...,0.07391,0.13393,0.226143,0.205528,0.284261,0.219417,0.215123,0.272299,0.203758,0.24802
2013-01-03,0.041567,0.229298,0.182997,0.090303,0.46015,0.516133,0.410587,0.371699,0.320095,0.357628,...,0.107586,0.297686,0.4512,0.379433,0.536884,0.463397,0.424858,0.528489,0.426369,0.502846
2013-01-04,0.128148,0.20583,0.347322,0.144561,0.264061,0.31902,0.357322,0.220658,0.295479,0.364025,...,0.124282,0.342673,0.427804,0.447503,0.505166,0.311061,0.373373,0.438596,0.329357,0.34778
2013-01-05,0.126854,0.346798,0.274581,0.188337,0.212976,0.22108,0.343575,0.272172,0.459154,0.448515,...,0.294815,0.35746,0.321136,0.415867,0.361462,0.349354,0.301562,0.275772,0.331298,0.337855


In [69]:
df_train_preprocessed = df_train.copy(deep=True)

In [73]:
tss_scaled = pd.DataFrame(
    index=df_train_preprocessed['temporal'].index,
    columns=df_train_preprocessed['temporal'].columns,
    data=preprocessing_pipeline.transform(
        df_train['temporal']
    )
)

tss_scaled

district,DE111,DE114,DE115,DE116,DE118,DE119,DE11A,DE11B,DE11C,DE11D,...,DEG0E,DEG0F,DEG0G,DEG0I,DEG0J,DEG0K,DEG0L,DEG0M,DEG0N,DEG0P
var,power,power,power,power,power,power,power,power,power,power,...,power,power,power,power,power,power,power,power,power,power
2013-01-01,6.727848,6.332357,6.693034,6.613113,6.590818,6.636133,6.548192,6.636133,6.266206,6.501372,...,6.386830,6.824102,6.598172,6.867197,6.793656,6.651914,6.501372,6.676286,6.598172,6.693034
2013-01-02,5.282803,5.199338,5.629049,5.179884,5.556967,5.489596,5.420670,5.524562,5.191001,5.307930,...,5.672127,5.918106,6.150615,6.034640,6.116225,5.984476,6.042544,6.120464,5.984476,6.034640
2013-01-03,5.524562,6.227721,6.095271,6.018986,6.856132,6.913633,6.684608,6.701568,6.370123,6.494913,...,5.950877,6.555127,6.783829,6.534515,6.764620,6.764620,6.727848,6.845267,6.701568,6.727848
2013-01-04,6.415438,6.107795,6.676286,6.445076,6.386830,6.359163,6.534515,6.218335,6.301226,6.507886,...,6.054500,6.736847,6.701568,6.774152,6.676286,6.271129,6.514456,6.527769,6.433091,6.348342
2013-01-05,6.398156,6.590818,6.439061,6.643979,6.195255,6.054500,6.507886,6.386830,6.736847,6.727848,...,6.845267,6.783829,6.421280,6.620703,6.332357,6.381225,6.321859,6.133282,6.439061,6.296141
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2015-06-18,6.091128,5.925321,5.965703,5.712927,5.622961,5.900230,5.903787,5.784024,5.903787,5.995873,...,5.096334,5.807263,5.641272,5.613859,5.647409,5.518708,5.545142,5.604791,5.571819,5.580770
2015-06-19,5.854761,5.598764,5.412134,5.571819,5.026180,5.352781,5.509947,5.302341,5.709759,5.735252,...,4.909079,5.764355,5.722463,5.626003,5.521634,4.997898,5.349970,5.321918,5.321918,5.288382
2015-06-20,5.316320,5.224351,4.992223,5.319119,4.992223,5.115872,5.324719,5.182663,5.277227,5.372495,...,5.274440,5.635152,5.638210,5.533365,5.536304,5.101920,5.389446,5.358407,5.252163,5.210454
2015-06-21,5.302341,4.986541,5.003567,5.129807,5.204896,5.554006,5.372495,5.319119,5.079556,5.452111,...,4.894556,5.082355,5.174324,5.014885,5.099127,5.475133,5.333126,5.179884,4.983697,5.096334


In [71]:
df_train_preprocessed['temporal'].update(tss_scaled)

df_train_preprocessed['temporal']

district,DE111,DE114,DE115,DE116,DE118,DE119,DE11A,DE11B,DE11C,DE11D,...,DEG0E,DEG0F,DEG0G,DEG0I,DEG0J,DEG0K,DEG0L,DEG0M,DEG0N,DEG0P
var,power,power,power,power,power,power,power,power,power,power,...,power,power,power,power,power,power,power,power,power,power
2013-01-01,6.727848,6.332357,6.693034,6.613113,6.590818,6.636133,6.548192,6.636133,6.266206,6.501372,...,6.386830,6.824102,6.598172,6.867197,6.793656,6.651914,6.501372,6.676286,6.598172,6.693034
2013-01-02,5.282803,5.199338,5.629049,5.179884,5.556967,5.489596,5.420670,5.524562,5.191001,5.307930,...,5.672127,5.918106,6.150615,6.034640,6.116225,5.984476,6.042544,6.120464,5.984476,6.034640
2013-01-03,5.524562,6.227721,6.095271,6.018986,6.856132,6.913633,6.684608,6.701568,6.370123,6.494913,...,5.950877,6.555127,6.783829,6.534515,6.764620,6.764620,6.727848,6.845267,6.701568,6.727848
2013-01-04,6.415438,6.107795,6.676286,6.445076,6.386830,6.359163,6.534515,6.218335,6.301226,6.507886,...,6.054500,6.736847,6.701568,6.774152,6.676286,6.271129,6.514456,6.527769,6.433091,6.348342
2013-01-05,6.398156,6.590818,6.439061,6.643979,6.195255,6.054500,6.507886,6.386830,6.736847,6.727848,...,6.845267,6.783829,6.421280,6.620703,6.332357,6.381225,6.321859,6.133282,6.439061,6.296141
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2015-06-18,6.091128,5.925321,5.965703,5.712927,5.622961,5.900230,5.903787,5.784024,5.903787,5.995873,...,5.096334,5.807263,5.641272,5.613859,5.647409,5.518708,5.545142,5.604791,5.571819,5.580770
2015-06-19,5.854761,5.598764,5.412134,5.571819,5.026180,5.352781,5.509947,5.302341,5.709759,5.735252,...,4.909079,5.764355,5.722463,5.626003,5.521634,4.997898,5.349970,5.321918,5.321918,5.288382
2015-06-20,5.316320,5.224351,4.992223,5.319119,4.992223,5.115872,5.324719,5.182663,5.277227,5.372495,...,5.274440,5.635152,5.638210,5.533365,5.536304,5.101920,5.389446,5.358407,5.252163,5.210454
2015-06-21,5.302341,4.986541,5.003567,5.129807,5.204896,5.554006,5.372495,5.319119,5.079556,5.452111,...,4.894556,5.082355,5.174324,5.014885,5.099127,5.475133,5.333126,5.179884,4.983697,5.096334


In [53]:
preprocessing_pipeline = preprocessing_pipeline.fit(
    df_train_preprocessed['temporal']
)

df_train_preprocessed['temporal'].update(
    preprocessing_pipeline.transform(
        df_train_preprocessed['temporal']
    )
)

df_train_preprocessed['temporal']

district,DE111,DE114,DE115,DE116,DE118,DE119,DE11A,DE11B,DE11C,DE11D,...,DEG0E,DEG0F,DEG0G,DEG0I,DEG0J,DEG0K,DEG0L,DEG0M,DEG0N,DEG0P
var,power,power,power,power,power,power,power,power,power,power,...,power,power,power,power,power,power,power,power,power,power
2013-01-01,0.178783,0.269458,0.351130,0.184231,0.357407,0.410291,0.362935,0.344386,0.284140,0.360815,...,0.187667,0.363388,0.393872,0.482197,0.551744,0.435241,0.372426,0.464629,0.393286,0.497463
2013-01-02,0.030363,0.063571,0.103089,0.030433,0.078740,0.108224,0.105199,0.093845,0.083709,0.093448,...,0.073910,0.133930,0.226143,0.205528,0.284261,0.219417,0.215123,0.272299,0.203758,0.248020
2013-01-03,0.041567,0.229298,0.182997,0.090303,0.460150,0.516133,0.410587,0.371699,0.320095,0.357628,...,0.107586,0.297686,0.451200,0.379433,0.536884,0.463397,0.424858,0.528489,0.426369,0.502846
2013-01-04,0.128148,0.205830,0.347322,0.144561,0.264061,0.319020,0.357322,0.220658,0.295479,0.364025,...,0.124282,0.342673,0.427804,0.447503,0.505166,0.311061,0.373373,0.438596,0.329357,0.347780
2013-01-05,0.126854,0.346798,0.274581,0.188337,0.212976,0.221080,0.343575,0.272172,0.459154,0.448515,...,0.294815,0.357460,0.321136,0.415867,0.361462,0.349354,0.301562,0.275772,0.331298,0.337855
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2015-06-18,0.083842,0.163687,0.163932,0.060402,0.091288,0.177450,0.184752,0.128957,0.193386,0.208285,...,0.035522,0.111379,0.118503,0.126316,0.156242,0.130795,0.122257,0.149732,0.123431,0.141750
2015-06-19,0.063776,0.107546,0.080187,0.048765,0.036200,0.092505,0.116975,0.070741,0.153659,0.151818,...,0.027530,0.106492,0.128349,0.128009,0.136787,0.068522,0.093796,0.108388,0.090027,0.094870
2015-06-20,0.031796,0.066823,0.046204,0.036060,0.034260,0.071043,0.096165,0.059447,0.090460,0.099396,...,0.044303,0.089142,0.118392,0.115530,0.139105,0.075821,0.099002,0.112456,0.082387,0.083599
2015-06-21,0.031256,0.048919,0.046900,0.028251,0.047832,0.120044,0.101213,0.072055,0.072594,0.108975,...,0.027350,0.040454,0.060277,0.055955,0.073054,0.123706,0.091807,0.089768,0.054912,0.071655


In [48]:
(df_train_preprocessed['temporal'] == 0).sum().sum()

5242

In [89]:
cv_params = {
    'method': 'expanding window',
    'window_size_first_pass': 365,
    'window_size_last_pass': 540,
    'n_passes': 3,
    'forecasting_window_size': 7,
}

In [90]:
import os
# os.getcwd()
os.chdir('/home/jonasmmiguel/Documents/learning/poli/thesis/wind-stf/src')
import sys
sys.path.append("../") # go to parent dir

from src.wind_stf.pipelines.data_science.nodes import define_cvsplits

In [91]:
a = define_cvsplits(**cv_params)

a

{'pass_1': {'train_idx': [0, 365], 'test_idx': [365, 372]},
 'pass_2': {'train_idx': [0, 452], 'test_idx': [452, 459]},
 'pass_3': {'train_idx': [0, 539], 'test_idx': [539, 546]}}

In [114]:
def define_cvsplits(cv_pars: Dict) -> Dict[str, Any]:  # Dict[str, List[pd.date_range, List[str]]]:
    """
    Example of Cross-Validation Splits Dictionary:

    cv_splits_dict = {
        'pass_1': {
            'train_idx': [0, 365],
            'test_idx': [365, 465],
        }
    }

    :param window_size_first_pass:
    :param window_size_last_pass:
    :param n_passes:
    :param forecasting_window_size:
    :return:
    """
    window_size_first_pass = cv_pars['window_size_first_pass']
    window_size_last_pass = cv_pars['window_size_last_pass']
    n_passes = cv_pars['n_passes']
    forecasting_window_size = cv_pars['forecasting_window_size']

    cv_splits_dict = {}
    window_size_increment = int( (window_size_last_pass - window_size_first_pass) / (n_passes-1) )
    for p in range(n_passes):
        pass_id = 'pass_' + str(p + 1)
        cv_splits_dict[pass_id] = {
                'train_idx': [
                    0,
                    window_size_first_pass + p * window_size_increment
                ],
                'test_idx': [
                    window_size_first_pass + p * window_size_increment,
                    window_size_first_pass + p * window_size_increment + forecasting_window_size,
                ],
        }
    return cv_splits_dict

In [122]:
def _split_train_val(df: pd.DataFrame, cv_splits_dict: dict, pass_id: str):
    train_idx_start = cv_splits_dict[pass_id]['train_idx'][0]
    train_idx_end = cv_splits_dict[pass_id]['train_idx'][1]

    test_idx_start = cv_splits_dict[pass_id]['test_idx'][0]
    test_idx_end = cv_splits_dict[pass_id]['test_idx'][1]

    return {
        'train': df.iloc[train_idx_start:train_idx_end, :],
        'val': df.iloc[test_idx_start:test_idx_end, :],
    }


class ForecastingModel:
    def __init__(self, y_train, modeling):
        self.hyperpars = modeling
        self.y_train = y_train
        
        self.targets_list = self.hyperpars['target_timeseries']
        if self.targets_list == 'all_available':
            self.targets_list = y_train.columns 
    
    def fit(self):
        
        if modeling['mode'] == 'temporal':  # i.e. districtwise
            self.submodels_ = { district: ForecastingModel(y_train['temporal'][district], hyperpars=modeling['inference']).fit() for district in targets_list } 
            
        elif modeling['mode'] == 'spatio-temporal':  # i.e. all districts at once 
            
            if modeling['inference']['approach'] == 'RNN-ES':
                self.model_ = None

            elif modeling['inference']['approach'] == 'GWNet':
                self.model_ = None
        
        else: 
            return NotImplementedError('')
        return self
    
    def predict(self, start, end, transformer):
        y_hat = self.model_.predict(start, end)
        y_hat_unscaled = transformer.inverse_transform(y_hat)
        return y_hat_unscaled


def _train(y_train, modeling):
    model_ = ForecastingModel(y_train[targets_list], modeling).fit()  
    return model_


def cv_train(df_train_preprocessed: pd.DataFrame,
             cv_splits_dict: Dict[str, Any],
             modeling: Dict[str, Any]) -> Dict[str, Any]:

    model = {}
    for pass_id in cv_splits_dict.keys():

        # splitting
        y = _split_train_val(df_train_preprocessed, cv_splits_dict, pass_id)

        # training
        model[pass_id] = {}
        model[pass_id] = _train(modeling, y_train=y['train'])

    return model

In [123]:
model = cv_train(df_train_preprocessed,
                 cv_splits_dict,
                 modeling)

  warn("Optimization failed to converge. Check mle_retvals.",
  warn("Optimization failed to converge. Check mle_retvals.",
  warn("Optimization failed to converge. Check mle_retvals.",


In [125]:
y['train'].info()

<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 903 entries, 2013-01-01 to 2015-06-22
Freq: D
Columns: 876 entries, ('spatial', 'DE111', 'lat') to ('temporal', 'DEG0P', 'power')
dtypes: float64(876)
memory usage: 6.0 MB


In [117]:
model['pass_3'].predict(
    start='2015-06-21',
    end='2015-06-27',
    scaler=preprocessing_pipeline,  # TODO: populate all districts columns, then predict
)

ValueError: operands could not be broadcast together with shapes (7,) (292,) 