In [71]:
import optuna
from optuna.integration.tensorboard import TensorBoardCallback
from optuna.trial import TrialState

from xgboost import XGBRegressor
from sklearn.model_selection import TimeSeriesSplit, GridSearchCV, RandomizedSearchCV
from sklearn.impute import SimpleImputer
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer, KNNImputer
from sklearn.pipeline import FeatureUnion, make_pipeline, Pipeline
from sklearn.compose import ColumnTransformer, make_column_selector
from sklearn.preprocessing import (
    LabelEncoder,
    StandardScaler,
    OneHotEncoder,
    FunctionTransformer,
)
from sklearn.metrics import r2_score, mean_absolute_error, mean_absolute_percentage_error, mean_squared_error
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn import config_context

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

import pandas as pd
import numpy as np
from plotnine import *
from mizani.formatters import comma_format, percent_format, currency_format
from datetime import datetime, timedelta, date
from tqdm.notebook import tqdm
from skimpy import clean_columns
from IPython.display import clear_output, display
import holidays
from pickle import dump, load
import warnings  

pd.set_option("display.max.columns", 50)

%load_ext blackcellmagic

import pandas as pd
import numpy as np
from pickle import load, dump

pd.set_option("display.max.columns", 50)

%load_ext blackcellmagic

The blackcellmagic extension is already loaded. To reload it, use:
  %reload_ext blackcellmagic
The blackcellmagic extension is already loaded. To reload it, use:
  %reload_ext blackcellmagic


### Loading the Final Preprocessor and Base Models, and Generating Test-Set Predictions

In [72]:
def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    deserialization finishes, including when ``load`` raises.

    NOTE(security): unpickling can execute arbitrary code. Only use this on
    files you trust (e.g. artefacts written by this project's own notebooks).
    """
    with open(path, "rb") as handle:
        return load(handle)


# Fitted preprocessing pipeline and the two fitted models, read from disk.
final_preprocessor = _load_pickle("Models/preprocessor_trainval.pickle")
final_en = _load_pickle("Models/en_trainval.pickle")
final_xg = _load_pickle("Models/xgboost_trainval.pickle")

In [73]:
# Read the preprocessed test set and keep only the first row for every
# `datetime` value; rows repeating an already-seen timestamp are dropped.
# The original row labels are retained (the index is not reset).
df_test = pd.read_csv("Preprocessed Data/df_test_preprocessed.csv").loc[
    lambda frame: ~frame["datetime"].duplicated()
]
df_test

Unnamed: 0,datetime,MWh,temperature_fore_ch,temperature_fore_fr,temperature_fore_de,temperature_fore_it,solar_fore_de_mw,solar_fore_it_mw,wind_fore_de_mw,wind_fore_it_mw,CH_AT,CH_DE,CH_FR,CH_IT,AT_CH,DE_CH,FR_CH,IT_CH,weekend,work_hour,year,hour_counter,holiday_name,hour_sin,hour_cos,...,target_lag_143,target_lag_144,target_lag_145,target_lag_146,target_lag_147,target_lag_148,target_lag_149,target_lag_150,target_lag_151,target_lag_152,target_lag_153,target_lag_154,target_lag_155,target_lag_156,target_lag_157,target_lag_158,target_lag_159,target_lag_160,target_lag_161,target_lag_162,target_lag_163,target_lag_164,target_lag_165,target_lag_166,target_lag_167
0,2022-01-01 00:00:00,146.054792,6.77,8.31,10.64,7.25,0.0,0.0,31805.65,1331.48,1200.0,4000.0,1400.0,3158.0,1200.0,800.0,3200.0,1910.0,1,0,2022,0,Neujahrestag,2.588190e-01,0.965926,...,,,,,,,,,,,,,,,,,,,,,,,,,
1,2022-01-01 01:00:00,139.133354,6.42,8.01,10.46,7.09,0.0,0.0,29880.67,1438.15,1200.0,4000.0,1400.0,3213.0,1200.0,800.0,3200.0,1910.0,1,0,2022,1,Neujahrestag,5.000000e-01,0.866025,...,,,,,,,,,,,,,,,,,,,,,,,,,
2,2022-01-01 02:00:00,147.562500,6.08,7.77,10.21,6.94,0.0,0.0,28826.75,1623.80,1200.0,4000.0,1400.0,2824.0,1200.0,800.0,3200.0,1910.0,1,0,2022,2,Neujahrestag,7.071068e-01,0.707107,...,,,,,,,,,,,,,,,,,,,,,,,,,
3,2022-01-01 03:00:00,157.636204,5.68,7.62,10.07,6.85,0.0,0.0,27631.75,1894.75,1200.0,4000.0,1400.0,2678.0,1200.0,800.0,3200.0,1910.0,1,0,2022,3,Neujahrestag,8.660254e-01,0.500000,...,,,,,,,,,,,,,,,,,,,,,,,,,
4,2022-01-01 04:00:00,163.326766,5.32,7.39,9.87,6.65,0.0,0.0,27128.00,2335.05,1200.0,4000.0,1400.0,2629.0,1200.0,800.0,3200.0,1910.0,1,0,2022,4,Neujahrestag,9.659258e-01,0.258819,...,,,,,,,,,,,,,,,,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8759,2022-12-31 19:00:00,67.876028,9.67,13.90,13.91,10.76,0.0,0.0,44107.32,569.68,1200.0,4000.0,1200.0,1507.0,1000.0,800.0,3700.0,1810.0,1,0,2022,8759,none,-8.660254e-01,0.500000,...,129.064727,142.919117,156.773506,143.824683,143.573879,140.957663,129.295384,124.736827,129.098016,114.592201,113.904337,110.687652,111.462452,110.015173,108.929387,125.451378,137.622234,135.428116,132.88221,130.426565,123.749551,149.211205,150.515579,197.461151,185.654299
8760,2022-12-31 20:00:00,72.765318,9.10,13.57,13.79,10.06,0.0,0.0,44512.60,459.45,1200.0,4000.0,1200.0,1507.0,1000.0,800.0,3700.0,1810.0,1,0,2022,8760,none,-7.071068e-01,0.707107,...,129.064727,142.919117,156.773506,143.824683,143.573879,140.957663,129.295384,124.736827,129.098016,114.592201,113.904337,110.687652,111.462452,110.015173,108.929387,125.451378,137.622234,135.428116,132.88221,130.426565,123.749551,149.211205,150.515579,197.461151,185.654299
8761,2022-12-31 21:00:00,81.277633,8.64,13.29,13.64,9.61,0.0,0.0,44946.45,399.35,1200.0,4000.0,1200.0,1507.0,1000.0,800.0,3700.0,1810.0,1,0,2022,8761,none,-5.000000e-01,0.866025,...,129.064727,142.919117,156.773506,143.824683,143.573879,140.957663,129.295384,124.736827,129.098016,114.592201,113.904337,110.687652,111.462452,110.015173,108.929387,125.451378,137.622234,135.428116,132.88221,130.426565,123.749551,149.211205,150.515579,197.461151,185.654299
8762,2022-12-31 22:00:00,95.496046,8.25,13.20,13.61,9.21,0.0,0.0,44938.83,420.78,1200.0,4000.0,1200.0,1459.0,1000.0,800.0,3700.0,1810.0,1,0,2022,8762,none,-2.588190e-01,0.965926,...,129.064727,142.919117,156.773506,143.824683,143.573879,140.957663,129.295384,124.736827,129.098016,114.592201,113.904337,110.687652,111.462452,110.015173,108.929387,125.451378,137.622234,135.428116,132.88221,130.426565,123.749551,149.211205,150.515579,197.461151,185.654299


In [74]:
# Run the test features through the fitted sklearn preprocessor.
# "MWh" and "datetime" are removed first so they are not passed to it.
test_features = df_test.drop(columns=["MWh", "datetime"])

# Wrap the transformed values in a DataFrame labelled with the preprocessor's
# output feature names (the result gets a fresh default index).
df_test_preprocessed = pd.DataFrame(
    data=final_preprocessor.transform(test_features),
    columns=final_preprocessor.get_feature_names_out(),
)

df_test_preprocessed



Unnamed: 0,numeric__temperature_fore_ch,numeric__temperature_fore_fr,numeric__temperature_fore_de,numeric__temperature_fore_it,numeric__solar_fore_de_mw,numeric__solar_fore_it_mw,numeric__wind_fore_de_mw,numeric__wind_fore_it_mw,numeric__CH_AT,numeric__CH_DE,numeric__CH_FR,numeric__CH_IT,numeric__AT_CH,numeric__DE_CH,numeric__FR_CH,numeric__IT_CH,numeric__year,numeric__hour_counter,numeric__hour_sin,numeric__hour_cos,numeric__week_hour_sin,numeric__week_hour_cos,numeric__month_sin,numeric__month_cos,numeric__quarter_sin,...,numeric__target_lag_157,numeric__target_lag_158,numeric__target_lag_159,numeric__target_lag_160,numeric__target_lag_161,numeric__target_lag_162,numeric__target_lag_163,numeric__target_lag_164,numeric__target_lag_165,numeric__target_lag_166,numeric__target_lag_167,categorical__weekend_0,categorical__weekend_1,categorical__work_hour_0,categorical__work_hour_1,categorical__holiday_name_Auffahrt,categorical__holiday_name_Karfreitag,categorical__holiday_name_Nationalfeiertag,categorical__holiday_name_Neujahrestag,categorical__holiday_name_Ostermontag,categorical__holiday_name_Ostern,categorical__holiday_name_Pfingsten,categorical__holiday_name_Pfingstmontag,categorical__holiday_name_Weihnachten,categorical__holiday_name_none
0,-0.518332,-0.712849,0.025025,-0.931294,-0.672406,-0.735088,1.826005,-0.591942,1.053063,0.574432,1.413978,0.489999,1.012158,-1.073783,0.936950,1.069218,2.450538,-1.731985,0.365660,1.365624,-1.390144,-0.260789,0.715311,1.225070,1.422261,...,0.000000,-4.242226e-16,0.000000,-4.508746e-16,0.000000,0.000000,0.000000,0.000000,4.030141e-16,-3.919117e-16,0.000000,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,-0.564253,-0.755848,0.001320,-0.950551,-0.672406,-0.735088,1.632348,-0.517654,1.053063,0.574432,1.413978,0.546368,1.012158,-1.073783,0.936950,1.069218,2.450538,-1.731853,0.706745,1.224345,-1.399002,-0.208640,0.715311,1.225070,1.422261,...,0.000000,-4.242226e-16,0.000000,-4.508746e-16,0.000000,0.000000,0.000000,0.000000,4.030141e-16,-3.919117e-16,0.000000,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
2,-0.608862,-0.790246,-0.031604,-0.968605,-0.672406,-0.735088,1.526322,-0.388361,1.053063,0.574432,1.413978,0.147683,1.012158,-1.073783,0.936950,1.069218,2.450538,-1.731722,0.999641,0.999602,-1.405904,-0.156196,0.715311,1.225070,1.422261,...,0.000000,-4.242226e-16,0.000000,-4.508746e-16,0.000000,0.000000,0.000000,0.000000,4.030141e-16,-3.919117e-16,0.000000,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
3,-0.661343,-0.811745,-0.050041,-0.979437,-0.672406,-0.735088,1.406102,-0.199662,1.053063,0.574432,1.413978,-0.001952,1.012158,-1.073783,0.936950,1.069218,2.450538,-1.731590,1.224388,0.706712,-1.410842,-0.103531,0.715311,1.225070,1.422261,...,0.000000,-4.242226e-16,0.000000,-4.508746e-16,0.000000,0.000000,0.000000,0.000000,4.030141e-16,-3.919117e-16,0.000000,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,-0.708577,-0.844710,-0.076380,-1.003509,-0.672406,-0.735088,1.355424,0.106977,1.053063,0.574432,1.413978,-0.052172,1.012158,-1.073783,0.936950,1.069218,2.450538,-1.731458,1.365670,0.365633,-1.413807,-0.050718,0.715311,1.225070,1.422261,...,0.000000,-4.242226e-16,0.000000,-4.508746e-16,0.000000,0.000000,0.000000,0.000000,4.030141e-16,-3.919117e-16,0.000000,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8755,-0.137844,0.088349,0.455668,-0.508836,-0.672406,-0.735088,3.063579,-1.122485,1.053063,0.574432,0.187339,-1.202107,0.440210,-1.073783,2.269662,0.501753,2.450538,-0.578908,-1.225125,0.706712,-1.225354,0.709380,0.006961,1.414212,0.006457,...,0.354314,6.932615e-01,1.011984,1.029418e+00,1.116588,1.007288,0.681557,1.332631,1.224470e+00,2.550092e+00,2.122568,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8756,-0.212629,0.041051,0.439864,-0.593086,-0.672406,-0.735088,3.104351,-1.199253,1.053063,0.574432,0.187339,-1.202107,0.440210,-1.073783,2.269662,0.501753,2.450538,-0.578776,-1.000378,0.999602,-1.198062,0.754687,0.006961,1.414212,0.006457,...,0.354314,6.932615e-01,1.011984,1.029418e+00,1.116588,1.007288,0.681557,1.332631,1.224470e+00,2.550092e+00,2.122568,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8757,-0.272983,0.000919,0.420110,-0.647248,-0.672406,-0.735088,3.147997,-1.241109,1.053063,0.574432,0.187339,-1.202107,0.440210,-1.073783,2.269662,0.501753,2.450538,-0.578645,-0.707482,1.224345,-1.169096,0.798941,0.006961,1.414212,0.006457,...,0.354314,6.932615e-01,1.011984,1.029418e+00,1.116588,1.007288,0.681557,1.332631,1.224470e+00,2.550092e+00,2.122568,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8758,-0.324152,-0.011980,0.416159,-0.695391,-0.672406,-0.735088,3.147230,-1.226184,1.053063,0.574432,0.187339,-1.251302,0.440210,-1.073783,2.269662,0.501753,2.450538,-0.578513,-0.366398,1.365624,-1.138495,0.842081,0.006961,1.414212,0.006457,...,0.354314,6.932615e-01,1.011984,1.029418e+00,1.116588,1.007288,0.681557,1.332631,1.224470e+00,2.550092e+00,2.122568,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0


In [76]:
# Pair the `final_en` predictions with the test timestamps and write them to
# CSV without the index column.
elasticnet_preds = pd.DataFrame(
    {
        "preds": final_en.predict(df_test_preprocessed),
        "datetime": df_test["datetime"],
    }
)
elasticnet_preds.to_csv("Predictions/elasticnet_preds.csv", index=False)



In [77]:
# Pair the `final_xg` predictions with the test timestamps and write them to
# CSV without the index column.
xgboost_preds = pd.DataFrame(
    {
        "preds": final_xg.predict(df_test_preprocessed),
        "datetime": df_test["datetime"],
    }
)
xgboost_preds.to_csv("Predictions/xgboost_preds.csv", index=False)

