In [9]:
import pathlib

import keras_tuner
import numpy as np
import openpyxl # Needed for reading excel
import pandas as pd
import tensorflow as tf

import data
import decomposition
import models
from metrics import smape
from windower import WindowGenerator


In [10]:
# Show the installed KerasTuner version so the run can be reproduced later.
keras_tuner.__version__

'1.3.0'

In [11]:
# List the logical devices TensorFlow can see, to confirm whether a GPU is available.
tf.config.list_logical_devices()

[LogicalDevice(name='/device:CPU:0', device_type='CPU'),
 LogicalDevice(name='/device:GPU:0', device_type='GPU')]

In [18]:
# Resolve project paths relative to the directory the notebook is run from.
cwd = pathlib.Path.cwd()

# The code root is taken to be two levels above the working directory; the
# Gonem notebooks folder is located beneath it.
code_directory = cwd.parents[1]
gonem_directory = code_directory.joinpath("notebooks", "Gonem")

In [19]:
# Load the maize dataset from the Gonem directory, drop its last two rows,
# and summarise every column.
# NOTE(review): the reason for discarding the final two rows is not recorded
# here (possibly incomplete trailing periods) — confirm.
df = data.get_data(directory_path=gonem_directory, product='maize').iloc[:-2]
df.describe()

Unnamed: 0_level_0,AVG_TAVG,AVG_TAVG,AVG_TAVG,AVG_TAVG,AVG_TAVG,Corn Price Futures,MAX_TMAX,MAX_TMAX,MAX_TMAX,MAX_TMAX,...,renewable_energy_consumption_perc_of_total,renewable_energy_consumption_perc_of_total,renewable_energy_consumption_perc_of_total,renewable_energy_consumption_perc_of_total,renewable_energy_consumption_perc_of_total,unemployment_total,unemployment_total,unemployment_total,unemployment_total,unemployment_total
PARTNER_Labels,Brazil,France,Germany,Hungary,Ukraine,Global,Brazil,France,Germany,Hungary,...,Brazil,France,Germany,Hungary,Ukraine,Brazil,France,Germany,Hungary,Ukraine
count,214.0,214.0,214.0,214.0,214.0,214.0,214.0,214.0,214.0,214.0,...,214.0,214.0,214.0,214.0,214.0,214.0,214.0,214.0,214.0,214.0
mean,238.152301,125.494113,104.143755,116.576205,96.484903,449.422897,345.918892,225.531663,218.804582,238.560202,...,45.976355,12.883762,13.639136,13.302453,4.538575,10.125584,8.96986,5.587944,6.981075,8.363902
std,24.10076,55.456797,65.178828,79.286625,90.200095,151.626084,24.273653,67.153277,81.278888,85.984321,...,1.906558,2.273139,2.851412,2.68874,2.132976,2.28875,0.867082,2.21661,2.708968,1.103329
min,146.692308,20.168095,-29.860742,-50.419892,-98.247057,201.75,282.0,107.622222,36.388889,43.666667,...,41.71,8.52,7.28,7.29,1.27,6.76,7.39,3.14,3.42,6.35
25%,227.991548,77.127932,48.280935,44.350331,16.710027,356.5,331.0,163.656487,140.181174,163.5,...,44.87625,11.155,11.109375,12.801875,2.81,8.205833,8.08625,3.64625,4.070833,7.4775
50%,245.982781,121.548459,102.466712,122.246111,92.398763,390.125,345.0,228.556851,226.650735,243.4,...,46.697083,13.286667,13.97125,13.64,3.49875,9.53625,9.06875,5.01125,7.343333,8.522083
75%,254.746595,177.640994,163.84533,190.244695,183.040152,562.3125,360.425,284.822581,287.334967,316.7,...,47.57,15.31,16.448125,15.377917,7.0825,12.47375,9.789167,7.403125,9.545,9.293125
max,276.134483,231.747995,229.864177,242.571429,239.800437,818.25,401.0,362.382979,371.823529,392.5,...,48.92,15.53,17.17,17.18,7.44,13.7,10.35,11.17,11.17,9.83


In [20]:
# Keep every column whose top-level name is 'price'; the resulting
# (name, partner) tuples are the columns used as labels.
is_label_column = df.columns.get_level_values(0).isin(['price'])
label_columns = df.columns[is_label_column].tolist()
label_columns

[('price', 'Brazil'),
 ('price', 'France'),
 ('price', 'Germany'),
 ('price', 'Global'),
 ('price', 'Hungary'),
 ('price', 'Ukraine')]

In [21]:
# Individual preprocessing steps. `std` is kept as its own name because its
# fitted `mean`/`std` are read again when the label post-processing is built.
# NOTE(review): period=12 presumably encodes yearly seasonality on monthly rows;
# the exact semantics of each step live in `decomposition` — confirm there.
stl = decomposition.STLDecomposer(labels=label_columns, period=12)
log = decomposition.Logger(labels=label_columns)
std = decomposition.Standardizer()

# Chain the steps onto a single Processor in this order: stl, log, std.
preproc = (
    decomposition.Processor()
    .add(stl)
    .add(log)
    .add(std)
)

In [22]:
# Window geometry: `width` input steps followed by `label_width` label steps,
# with the labels ending `shift` steps after the input (total window = 30).
width = 24
label_width = 6
shift = 6

# Build the windowed train/val/test splits over `df`.
# Other split boundaries tried previously:
#   train_end=.9, val_end=.96   and   train_end=.5, val_end=.8
# NOTE(review): with val_begin=None and val_end=None the validation frame shown
# below ends on the same rows as the training frame, so validation appears to
# reuse training data — confirm this is intended.
# NOTE(review): the begin/end values are presumably fractions of the series
# length; verify against WindowGenerator.
w = WindowGenerator(
    input_width=width,
    label_width=label_width,
    shift=shift,
    data=df,
    train_begin=0,
    train_end=.97,
    val_begin=None,
    val_end=None,
    test_begin=None,
    test_end=1.,
    connect=True,
    remove_labels=True,
    label_columns=label_columns,
)

# Apply the preprocessing pipeline to the splits, then display the window summary.
w.preprocess(preproc)
w

Total window size: 30
Input indices: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
Label indices: [24 25 26 27 28 29]
Label column name(s): [('price', 'Brazil'), ('price', 'France'), ('price', 'Germany'), ('price', 'Global'), ('price', 'Hungary'), ('price', 'Ukraine')]

In [23]:
# Inspect the last five rows of the training split after preprocessing.
w.train_df.tail(n=5)

Unnamed: 0_level_0,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,...,unemployment_total,unemployment_total,unemployment_total,unemployment_total,price,price,price,price,price,price
Unnamed: 0_level_1,Brazil,Brazil,Brazil,France,France,France,Germany,Germany,Germany,Global,...,France,Germany,Hungary,Ukraine,Brazil,France,Germany,Global,Hungary,Ukraine
TIME_PERIOD,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2021-11-01,1.152042,4.087749,0.822301,1.89249,-2.381107,0.321732,1.801032,-0.824824,-0.475178,2.44171,...,-1.339159,-0.938878,-1.122596,1.39341,3.211244,1.103208,0.984585,1.614126,2.35354,0.980424
2021-12-01,1.147873,-0.569034,-0.26402,1.761404,2.749452,-2.406993,1.925765,-0.147694,-1.175928,2.566864,...,-1.339159,-0.938878,-1.122596,1.39341,0.111976,1.183113,0.962737,1.88283,0.948267,1.66926
2022-01-01,1.139729,-0.155443,-0.079601,1.621999,-1.22754,0.879453,2.051231,0.404508,0.008637,2.689858,...,-1.339159,-0.938878,-1.122596,1.39341,0.949587,1.334444,1.553425,1.794491,1.94641,1.251746
2022-02-01,1.126977,0.017135,-0.103347,1.474924,2.955088,-1.669538,2.17774,0.218855,0.445674,2.810762,...,-1.339159,-0.938878,-1.122596,1.39341,1.060388,1.271195,1.721903,2.165919,1.881186,1.750171
2022-03-01,1.109234,-0.281005,-0.409841,1.3208,2.378568,-0.738195,2.305615,1.126411,0.767055,2.929666,...,-1.339159,-0.938878,-1.122596,1.39341,0.155542,1.348502,2.046054,2.268144,1.755316,1.86988


In [24]:
# Inspect the first five rows of the validation split.
# NOTE(review): to check whether it duplicates the training split, use
# `w.train_df.equals(w.val_df)`; `all(w.train_df == w.val_df)` only iterates
# over the column labels and is therefore always truthy.
w.val_df.head(n=5)

Unnamed: 0_level_0,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,...,unemployment_total,unemployment_total,unemployment_total,unemployment_total,price,price,price,price,price,price
Unnamed: 0_level_1,Brazil,Brazil,Brazil,France,France,France,Germany,Germany,Germany,Global,...,France,Germany,Hungary,Ukraine,Brazil,France,Germany,Global,Hungary,Ukraine
TIME_PERIOD,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2005-01-01,-1.580179,-0.033289,0.168993,-1.931166,0.731859,-0.240543,-2.378319,-0.180142,0.543452,-2.080851,...,-0.148683,2.48147,0.04068,-1.042821,-0.521366,-1.785991,-1.989881,-1.937839,-1.078649,0.184703
2005-02-01,-1.505664,-0.103264,0.185598,-1.898259,0.780536,-0.264521,-2.347243,0.30646,-0.372704,-2.044557,...,-0.153546,2.446966,0.049942,-1.071167,-0.521366,-1.738325,-2.414601,-2.10074,-1.150978,0.012096
2005-03-01,-1.429504,0.201713,-0.019082,-1.865651,-0.147951,0.238452,-2.316005,0.471407,-0.052199,-2.008348,...,-0.158409,2.412463,0.059204,-1.099513,-0.48555,-1.74955,-2.04902,-1.980676,-1.209392,0.311921
2005-04-01,-1.352375,0.385311,-0.185514,-1.83316,-0.300116,-0.0654,-2.284661,0.164613,-0.349635,-1.972104,...,-0.163272,2.377959,0.068465,-1.127859,-0.589783,-1.968201,-2.378616,-2.210023,-1.174462,-0.05831
2005-05-01,-1.275052,0.381132,-0.171745,-1.800612,-0.030081,0.010476,-2.253189,0.176419,-0.075181,-1.935701,...,-0.168135,2.343456,0.077727,-1.156205,-0.440768,-1.762277,-2.119592,-1.93984,-0.978474,0.470985


In [25]:
# Inspect the last five rows of the validation split.
w.val_df.tail(n=5)

Unnamed: 0_level_0,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,...,unemployment_total,unemployment_total,unemployment_total,unemployment_total,price,price,price,price,price,price
Unnamed: 0_level_1,Brazil,Brazil,Brazil,France,France,France,Germany,Germany,Germany,Global,...,France,Germany,Hungary,Ukraine,Brazil,France,Germany,Global,Hungary,Ukraine
TIME_PERIOD,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2021-11-01,1.152042,4.087749,0.822301,1.89249,-2.381107,0.321732,1.801032,-0.824824,-0.475178,2.44171,...,-1.339159,-0.938878,-1.122596,1.39341,3.211244,1.103208,0.984585,1.614126,2.35354,0.980424
2021-12-01,1.147873,-0.569034,-0.26402,1.761404,2.749452,-2.406993,1.925765,-0.147694,-1.175928,2.566864,...,-1.339159,-0.938878,-1.122596,1.39341,0.111976,1.183113,0.962737,1.88283,0.948267,1.66926
2022-01-01,1.139729,-0.155443,-0.079601,1.621999,-1.22754,0.879453,2.051231,0.404508,0.008637,2.689858,...,-1.339159,-0.938878,-1.122596,1.39341,0.949587,1.334444,1.553425,1.794491,1.94641,1.251746
2022-02-01,1.126977,0.017135,-0.103347,1.474924,2.955088,-1.669538,2.17774,0.218855,0.445674,2.810762,...,-1.339159,-0.938878,-1.122596,1.39341,1.060388,1.271195,1.721903,2.165919,1.881186,1.750171
2022-03-01,1.109234,-0.281005,-0.409841,1.3208,2.378568,-0.738195,2.305615,1.126411,0.767055,2.929666,...,-1.339159,-0.938878,-1.122596,1.39341,0.155542,1.348502,2.046054,2.268144,1.755316,1.86988


In [26]:
# Inspect the first five rows of the test split.
w.test_df.head(n=5)

Unnamed: 0_level_0,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,...,unemployment_total,unemployment_total,unemployment_total,unemployment_total,price,price,price,price,price,price
Unnamed: 0_level_1,Brazil,Brazil,Brazil,France,France,France,Germany,Germany,Germany,Global,...,France,Germany,Hungary,Ukraine,Brazil,France,Germany,Global,Hungary,Ukraine
TIME_PERIOD,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2019-11-01,-1.356312,1.783123,0.613835,-1.477992,2.364478,0.566794,0.374042,0.125635,-0.593035,-0.972535,...,-1.08628,-0.86237,-1.099751,0.873988,1.984879,-0.297268,0.176812,-0.233763,0.109866,-0.888247
2019-12-01,-1.260415,0.267536,-0.122423,-1.199474,17.319034,1.091934,0.415039,3.843534,-0.712398,-0.837209,...,-1.125184,-0.835368,-1.074127,0.972816,-0.497831,2.896538,1.063024,0.046399,-0.321027,-0.834174
2020-01-01,-1.160154,0.058314,-0.045474,-0.91953,1.771234,0.677278,0.451964,0.162017,-0.645471,-0.701512,...,-1.164089,-0.808365,-1.048502,1.071644,-0.560252,0.083916,0.223349,-0.162847,2.754658,-0.815125
2020-02-01,-1.056111,0.07307,-0.070002,-0.638764,4.250008,-1.421001,0.484864,2.193019,-0.559492,-0.565529,...,-1.178678,-0.819241,-1.054677,1.098458,-0.477709,0.103096,0.789385,-0.095393,0.244221,-0.724015
2020-03-01,-0.948788,0.297082,-0.23247,-0.35805,2.486478,-1.029892,0.514118,1.337957,-0.407893,-0.429376,...,-1.193267,-0.830117,-1.060851,1.125272,-0.429849,0.041517,0.663871,0.132499,0.15547,-0.50304


In [27]:
# Inspect the last five rows of the test split.
w.test_df.tail(n=5)

Unnamed: 0_level_0,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,price_seasonal,price_residual,price_trend,...,unemployment_total,unemployment_total,unemployment_total,unemployment_total,price,price,price,price,price,price
Unnamed: 0_level_1,Brazil,Brazil,Brazil,France,France,France,Germany,Germany,Germany,Global,...,France,Germany,Hungary,Ukraine,Brazil,France,Germany,Global,Hungary,Ukraine
TIME_PERIOD,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2
2022-06-01,0.941419,-0.475776,0.112838,2.588959,-0.313244,-0.60225,5.819521,9.959318,0.842895,4.159954,...,-1.339159,-0.938878,-1.122596,1.39341,0.835662,1.654303,4.493844,3.611875,1.441408,3.622691
2022-07-01,0.885367,-0.668505,0.135618,2.650066,4.235776,0.846882,6.203896,-2.175068,-0.417677,4.361691,...,-1.339159,-0.938878,-1.122596,1.39341,0.639702,2.762692,2.962163,4.164267,1.253016,3.271409
2022-08-01,0.829343,-1.141253,0.110965,2.715711,-0.829444,0.113682,6.585235,-6.290689,-0.638491,4.56352,...,-1.339159,-0.938878,-1.122596,1.39341,-0.134071,1.835296,2.455755,1.795943,1.583087,3.247852
2022-09-01,0.773468,-0.883928,0.175384,2.786118,10.25274,0.705032,6.962552,-19.853076,-1.983182,4.765273,...,-1.339159,-0.938878,-1.122596,1.39341,0.379337,3.547973,-0.90089,2.818962,1.176588,2.535439
2022-10-01,0.717567,-0.057395,-0.269525,2.861351,9.428552,1.413645,7.335331,-9.693487,-0.582022,4.96692,...,-1.339159,-0.938878,-1.122596,1.39341,0.467988,3.60955,2.263762,3.389607,1.269017,2.525034


In [28]:
# Label-only post-processing, built from the statistics that the fitted
# standardizer `std` holds for the label columns.
# NOTE(review): presumably this undoes the standardization and then the log
# transform on model outputs — confirm against `decomposition`.
# A min-max-scaler variant (built from `mms.min` / `mms.max`) is currently disabled.
label_std = decomposition.Standardizer(
    mean=std.mean[w.label_columns],
    std=std.std[w.label_columns],
)
label_log = decomposition.Logger(
    label_indices=range(len(w.label_columns)),
)

postproc = (
    decomposition.Processor()
    .add(label_std)
    .add(label_log)
)
w.add_label_postprocess(postproc)

In [29]:
# Pull a single batch from the training dataset to inspect tensor shapes and
# to derive the number of output features.
example_batch = next(iter(w.train.take(1)), None)
if example_batch is None:
    # Fail here with a clear message instead of leaving `output_features`
    # undefined and hitting a NameError in a later cell.
    raise ValueError(
        "w.train produced no batches; the training split is too short for the "
        "configured window size."
    )

example_inputs, example_labels = example_batch
print(f'Inputs shape (batch, time, features): {example_inputs.shape}')
print(f'Labels shape (batch, time, features): {example_labels.shape}')

# The last axis of the labels is the number of label features.
output_features = example_labels.shape[-1]

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
Inputs shape (batch, time, features): (32, 24, 75)
Labels shape (batch, time, features): (32, 6, 6)
