In [1]:
import os
import pickle

import pandas as pd

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from keras.callbacks import EarlyStopping

from src.data_preparation.data_preprocessing import DataReshaperLSTM
from src.visualization.metrics import PredictionEvaluator, GlobalResults
from src.config import Config

## Load the data

In [2]:
config = Config()
variant_co2 = 'co2'

# --- LSTM inputs -----------------------------------------------------------
# Both splits are indexed by (year, <config.additional_index>) and lose the
# helper column 'country_order'.
lstm_index_columns = ["year", config.additional_index]
train_lstm_path = os.path.join(config.output_cleaned_lstm, f'{variant_co2}/train.csv')
test_lstm_path = os.path.join(config.output_cleaned_lstm, f'{variant_co2}/test.csv')

train_lstm_df = (
    pd.read_csv(train_lstm_path)
    .set_index(lstm_index_columns)
    .drop(columns=['country_order'])
)
test_lstm_df = (
    pd.read_csv(test_lstm_path)
    .set_index(lstm_index_columns)
    .drop(columns=['country_order'])
)

# The same pickle is read twice on purpose: each load yields its own object.
# Only `data_preprocessor_lstm` is used in this cell (on the LightGBM
# predictions); `data_preprocessor_lstm_for_prediction` is left untouched here.
# NOTE(review): presumably this keeps preprocess_data() from altering the state
# of the copy reserved for inverse-transforming predictions -- confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
preprocessor_path = os.path.join(config.output_cleaned_lstm, f'{variant_co2}/data_preprocessor_lstm.pkl')

with open(preprocessor_path, 'rb') as preprocessor_file:
    data_preprocessor_lstm_for_prediction = pickle.load(preprocessor_file)

with open(preprocessor_path, 'rb') as preprocessor_file:
    data_preprocessor_lstm = pickle.load(preprocessor_file)

# --- LightGBM predictions --------------------------------------------------
# Indexed by (year, country); the 'country' level is renamed so that the index
# names match the LSTM frames above.
country_level_rename = {"country": config.additional_index}
train_lightgbm_pred_path = os.path.join(config.output_cleaned_hybrid, f'lightgbm_lstm/{variant_co2}/lightgbm_pred_train.csv')
test_lightgbm_pred_path = os.path.join(config.output_cleaned_hybrid, f'lightgbm_lstm/{variant_co2}/lightgbm_pred_test.csv')

train_lightgbm_pred_df = (
    pd.read_csv(train_lightgbm_pred_path)
    .set_index(["year", "country"])
    .rename_axis(index=country_level_rename)
)
test_lightgbm_pred_df = (
    pd.read_csv(test_lightgbm_pred_path)
    .set_index(["year", "country"])
    .rename_axis(index=country_level_rename)
)

preprocessed_pair = data_preprocessor_lstm.preprocess_data(train_lightgbm_pred_df, test_lightgbm_pred_df)
train_lightgbm_pred_preprocessed, test_lightgbm_pred_preprocessed = preprocessed_pair

In [3]:
# Expose the single LightGBM prediction under both lag names (t-2 first, then
# t-1) and remove the unsuffixed original. Both lag columns hold the same
# values; no shifting is applied here.
# The frame is modified in place, so this cell cannot be re-run without first
# re-running the loading cell (the 'lightgbm_pred' column is gone afterwards).
for lag_suffix in ('t-2', 't-1'):
    train_lightgbm_pred_preprocessed[f'lightgbm_pred_{lag_suffix}'] = train_lightgbm_pred_preprocessed['lightgbm_pred']
del train_lightgbm_pred_preprocessed['lightgbm_pred']
train_lightgbm_pred_preprocessed

Unnamed: 0_level_0,Unnamed: 1_level_0,lightgbm_pred_t-2,lightgbm_pred_t-1
year,country_index,Unnamed: 2_level_1,Unnamed: 3_level_1
1961,Arabia Saudyjska,0.007998,0.007998
1962,Arabia Saudyjska,0.008097,0.008097
1963,Arabia Saudyjska,0.008935,0.008935
1964,Arabia Saudyjska,0.009510,0.009510
1965,Arabia Saudyjska,0.009659,0.009659
...,...,...,...
2000,Włochy,0.090754,0.090754
2001,Włochy,0.092026,0.092026
2002,Włochy,0.089925,0.089925
2003,Włochy,0.092097,0.092097


In [4]:
# Same treatment as the train split: duplicate the LightGBM prediction into
# the t-2 and t-1 columns (identical values, no shifting) and drop the
# unsuffixed original.
# In-place modification: re-running this cell alone fails because
# 'lightgbm_pred' no longer exists afterwards.
for lag_suffix in ('t-2', 't-1'):
    test_lightgbm_pred_preprocessed[f'lightgbm_pred_{lag_suffix}'] = test_lightgbm_pred_preprocessed['lightgbm_pred']
del test_lightgbm_pred_preprocessed['lightgbm_pred']
test_lightgbm_pred_preprocessed

Unnamed: 0_level_0,Unnamed: 1_level_0,lightgbm_pred_t-2,lightgbm_pred_t-1
year,country_index,Unnamed: 2_level_1,Unnamed: 3_level_1
1989,Stany Zjednoczone,0.992012,0.992012
1990,Stany Zjednoczone,0.992204,0.992204
1991,Stany Zjednoczone,0.987098,0.987098
1992,Stany Zjednoczone,0.912564,0.912564
1993,Stany Zjednoczone,0.912564,0.912564
...,...,...,...
2019,Szwecja,0.022017,0.022017
2020,Szwecja,0.023969,0.023969
2021,Szwecja,0.023496,0.023496
2022,Szwecja,0.023485,0.023485


In [5]:
# Combine the train frames: LightGBM lag columns on the left, LSTM features on
# the right, matched on the shared index. The default inner join is used, so
# only index entries present in both frames are kept.
train_hybrid_df = train_lightgbm_pred_preprocessed.merge(
    train_lstm_df,
    left_index=True,
    right_index=True,
)
train_hybrid_df

Unnamed: 0_level_0,Unnamed: 1_level_0,lightgbm_pred_t-2,lightgbm_pred_t-1,country_t-2,population_t-2,gdp_t-2,temperature_change_from_co2_t-2,cement_co2_t-2,coal_co2_t-2,flaring_co2_t-2,gas_co2_t-2,...,population_t-1,gdp_t-1,temperature_change_from_co2_t-1,cement_co2_t-1,coal_co2_t-1,flaring_co2_t-1,gas_co2_t-1,land_use_change_co2_t-1,oil_co2_t-1,co2
year,country_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
1961,Arabia Saudyjska,0.007998,0.007998,0,0.000194,0.002745,0.000000,0.000087,0.000000,0.000000,0.000000,...,0.000214,0.002734,0.000000,0.000094,0.000000,0.000000,0.000000,0.038131,0.001042,0.014038
1962,Arabia Saudyjska,0.008097,0.008097,0,0.000261,0.003040,0.000000,0.000107,0.000000,0.000000,0.000000,...,0.000284,0.003084,0.000000,0.000108,0.000000,0.000000,0.000000,0.038220,0.001393,0.014426
1963,Arabia Saudyjska,0.008935,0.008935,0,0.000331,0.003426,0.000000,0.000124,0.000000,0.000000,0.000000,...,0.000355,0.003477,0.000000,0.000194,0.000000,0.000000,0.000019,0.038285,0.002431,0.014563
1964,Arabia Saudyjska,0.009510,0.009510,0,0.000403,0.003860,0.000000,0.000221,0.000000,0.000000,0.000019,...,0.000430,0.003799,0.000000,0.000208,0.000000,0.000000,0.000025,0.038396,0.002698,0.014572
1965,Arabia Saudyjska,0.009659,0.009659,0,0.000478,0.004214,0.000000,0.000238,0.000000,0.000000,0.000025,...,0.000508,0.004152,0.000000,0.000270,0.000000,0.000000,0.001053,0.038383,0.002248,0.014198
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2000,Włochy,0.090754,0.090754,44,0.042144,0.202657,0.046358,0.035316,0.012401,0.051699,0.100876,...,0.041863,0.187879,0.045752,0.032498,0.010996,0.041351,0.109807,0.027605,0.102774,0.072296
2001,Włochy,0.092026,0.092026,44,0.042163,0.207264,0.046358,0.037112,0.012192,0.041351,0.109807,...,0.041887,0.196229,0.052288,0.033743,0.011423,0.043386,0.114710,0.026136,0.100887,0.071542
2002,Włochy,0.089925,0.089925,44,0.042187,0.216475,0.052980,0.038533,0.012664,0.043386,0.114710,...,0.041912,0.201132,0.052288,0.034874,0.012033,0.040846,0.114960,0.024196,0.099381,0.071573
2003,Włochy,0.092097,0.092097,44,0.042213,0.221884,0.052980,0.039825,0.013341,0.040846,0.114960,...,0.042000,0.203138,0.052288,0.034651,0.012235,0.037813,0.115605,0.021727,0.101712,0.074712


In [6]:
# Combine the test frames: LightGBM lag columns on the left, LSTM features on
# the right, matched on the shared index. The default inner join is used, so
# only index entries present in both frames are kept.
test_hybrid_df = test_lightgbm_pred_preprocessed.merge(
    test_lstm_df,
    left_index=True,
    right_index=True,
)
test_hybrid_df

Unnamed: 0_level_0,Unnamed: 1_level_0,lightgbm_pred_t-2,lightgbm_pred_t-1,country_t-2,population_t-2,gdp_t-2,temperature_change_from_co2_t-2,cement_co2_t-2,coal_co2_t-2,flaring_co2_t-2,gas_co2_t-2,...,population_t-1,gdp_t-1,temperature_change_from_co2_t-1,cement_co2_t-1,coal_co2_t-1,flaring_co2_t-1,gas_co2_t-1,land_use_change_co2_t-1,oil_co2_t-1,co2
year,country_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
1989,Stany Zjednoczone,0.992012,0.992012,34,0.186605,0.980192,1.013245,0.080107,0.386410,0.088209,0.785603,...,0.187127,0.925972,1.013072,0.070727,0.365490,0.101656,0.820841,0.131566,0.887696,0.744042
1990,Stany Zjednoczone,0.992204,0.992204,34,0.188312,1.021423,1.026490,0.080768,0.405226,0.101656,0.820841,...,0.188861,0.958024,1.026144,0.070761,0.369653,0.101017,0.874110,0.130467,0.887150,0.744836
1991,Stany Zjednoczone,0.987098,0.987098,34,0.190056,1.056778,1.039735,0.080807,0.409841,0.101017,0.874110,...,0.190911,0.974791,1.045752,0.071229,0.369909,0.616333,0.870813,0.136453,0.862972,0.729761
1992,Stany Zjednoczone,0.912564,0.912564,34,0.192119,1.075274,1.059603,0.081341,0.410125,0.616333,0.870813,...,0.193265,0.972084,1.058824,0.069638,0.366365,0.603152,0.886989,0.118116,0.840136,0.738090
1993,Stany Zjednoczone,0.912564,0.912564,34,0.194488,1.072287,1.072848,0.079524,0.406197,0.603152,0.886989,...,0.195648,1.006652,1.071895,0.070185,0.370468,0.598723,0.919292,0.101080,0.861291,0.749304
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2019,Szwecja,0.022017,0.022017,35,0.006089,0.052010,0.019868,0.003605,0.001879,0.010654,0.001542,...,0.006096,0.048052,0.019608,0.003419,0.001650,0.010973,0.001605,0.039618,0.011526,0.019703
2020,Szwecja,0.023969,0.023969,35,0.006179,0.053028,0.019868,0.003904,0.001829,0.010973,0.001605,...,0.006175,0.049011,0.019608,0.002887,0.001654,0.007129,0.001515,0.039806,0.011298,0.019193
2021,Szwecja,0.023496,0.023496,35,0.006258,0.054085,0.019868,0.003296,0.001834,0.007129,0.001515,...,0.006232,0.047943,0.019608,0.002706,0.001355,0.005919,0.001227,0.039930,0.010443,0.019574
2022,Szwecja,0.023485,0.023485,35,0.006315,0.052907,0.019868,0.003090,0.001503,0.005919,0.001227,...,0.006279,0.050536,0.019608,0.002678,0.001425,0.000120,0.001618,0.040232,0.011072,0.019473


In [7]:
# Group the train columns by name suffix and lay them out as:
# every '...t-2' column, then every '...t-1' column, then the columns ending
# in 'co2' (the lagged '*_co2_t-N' features end in 't-N', so they are not
# picked up by the 'co2' filter).
# NOTE(review): presumably this is the layout DataReshaperLSTM expects -- confirm.
train_column_names = train_hybrid_df.columns
columns_train_t2 = train_column_names[train_column_names.str.endswith('t-2')].tolist()
columns_train_t1 = train_column_names[train_column_names.str.endswith('t-1')].tolist()
co2_train = train_column_names[train_column_names.str.endswith('co2')].tolist()

train_column_order = columns_train_t2 + columns_train_t1 + co2_train
train_hybrid_df_sorted = train_hybrid_df.loc[:, train_column_order].copy()
train_hybrid_df_sorted

Unnamed: 0_level_0,Unnamed: 1_level_0,lightgbm_pred_t-2,country_t-2,population_t-2,gdp_t-2,temperature_change_from_co2_t-2,cement_co2_t-2,coal_co2_t-2,flaring_co2_t-2,gas_co2_t-2,land_use_change_co2_t-2,...,population_t-1,gdp_t-1,temperature_change_from_co2_t-1,cement_co2_t-1,coal_co2_t-1,flaring_co2_t-1,gas_co2_t-1,land_use_change_co2_t-1,oil_co2_t-1,co2
year,country_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
1961,Arabia Saudyjska,0.007998,0,0.000194,0.002745,0.000000,0.000087,0.000000,0.000000,0.000000,0.051685,...,0.000214,0.002734,0.000000,0.000094,0.000000,0.000000,0.000000,0.038131,0.001042,0.014038
1962,Arabia Saudyjska,0.008097,0,0.000261,0.003040,0.000000,0.000107,0.000000,0.000000,0.000000,0.051399,...,0.000284,0.003084,0.000000,0.000108,0.000000,0.000000,0.000000,0.038220,0.001393,0.014426
1963,Arabia Saudyjska,0.008935,0,0.000331,0.003426,0.000000,0.000124,0.000000,0.000000,0.000000,0.051518,...,0.000355,0.003477,0.000000,0.000194,0.000000,0.000000,0.000019,0.038285,0.002431,0.014563
1964,Arabia Saudyjska,0.009510,0,0.000403,0.003860,0.000000,0.000221,0.000000,0.000000,0.000019,0.051607,...,0.000430,0.003799,0.000000,0.000208,0.000000,0.000000,0.000025,0.038396,0.002698,0.014572
1965,Arabia Saudyjska,0.009659,0,0.000478,0.004214,0.000000,0.000238,0.000000,0.000000,0.000025,0.051756,...,0.000508,0.004152,0.000000,0.000270,0.000000,0.000000,0.001053,0.038383,0.002248,0.014198
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2000,Włochy,0.090754,44,0.042144,0.202657,0.046358,0.035316,0.012401,0.051699,0.100876,0.042772,...,0.041863,0.187879,0.045752,0.032498,0.010996,0.041351,0.109807,0.027605,0.102774,0.072296
2001,Włochy,0.092026,44,0.042163,0.207264,0.046358,0.037112,0.012192,0.041351,0.109807,0.037209,...,0.041887,0.196229,0.052288,0.033743,0.011423,0.043386,0.114710,0.026136,0.100887,0.071542
2002,Włochy,0.089925,44,0.042187,0.216475,0.052980,0.038533,0.012664,0.043386,0.114710,0.035230,...,0.041912,0.201132,0.052288,0.034874,0.012033,0.040846,0.114960,0.024196,0.099381,0.071573
2003,Włochy,0.092097,44,0.042213,0.221884,0.052980,0.039825,0.013341,0.040846,0.114960,0.032615,...,0.042000,0.203138,0.052288,0.034651,0.012235,0.037813,0.115605,0.021727,0.101712,0.074712


In [8]:
# Group the test columns by name suffix and lay them out as:
# every '...t-2' column, then every '...t-1' column, then the columns ending
# in 'co2' (the lagged '*_co2_t-N' features end in 't-N', so they are not
# picked up by the 'co2' filter).
# NOTE(review): presumably this is the layout DataReshaperLSTM expects -- confirm.
test_column_names = test_hybrid_df.columns
columns_test_t2 = test_column_names[test_column_names.str.endswith('t-2')].tolist()
columns_test_t1 = test_column_names[test_column_names.str.endswith('t-1')].tolist()
co2_test = test_column_names[test_column_names.str.endswith('co2')].tolist()

test_column_order = columns_test_t2 + columns_test_t1 + co2_test
test_hybrid_df_sorted = test_hybrid_df.loc[:, test_column_order].copy()
test_hybrid_df_sorted

Unnamed: 0_level_0,Unnamed: 1_level_0,lightgbm_pred_t-2,country_t-2,population_t-2,gdp_t-2,temperature_change_from_co2_t-2,cement_co2_t-2,coal_co2_t-2,flaring_co2_t-2,gas_co2_t-2,land_use_change_co2_t-2,...,population_t-1,gdp_t-1,temperature_change_from_co2_t-1,cement_co2_t-1,coal_co2_t-1,flaring_co2_t-1,gas_co2_t-1,land_use_change_co2_t-1,oil_co2_t-1,co2
year,country_index,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
1989,Stany Zjednoczone,0.992012,34,0.186605,0.980192,1.013245,0.080107,0.386410,0.088209,0.785603,0.170489,...,0.187127,0.925972,1.013072,0.070727,0.365490,0.101656,0.820841,0.131566,0.887696,0.744042
1990,Stany Zjednoczone,0.992204,34,0.188312,1.021423,1.026490,0.080768,0.405226,0.101656,0.820841,0.177344,...,0.188861,0.958024,1.026144,0.070761,0.369653,0.101017,0.874110,0.130467,0.887150,0.744836
1991,Stany Zjednoczone,0.987098,34,0.190056,1.056778,1.039735,0.080807,0.409841,0.101017,0.874110,0.175863,...,0.190911,0.974791,1.045752,0.071229,0.369909,0.616333,0.870813,0.136453,0.862972,0.729761
1992,Stany Zjednoczone,0.912564,34,0.192119,1.075274,1.059603,0.081341,0.410125,0.616333,0.870813,0.183932,...,0.193265,0.972084,1.058824,0.069638,0.366365,0.603152,0.886989,0.118116,0.840136,0.738090
1993,Stany Zjednoczone,0.912564,34,0.194488,1.072287,1.072848,0.079524,0.406197,0.603152,0.886989,0.159215,...,0.195648,1.006652,1.071895,0.070185,0.370468,0.598723,0.919292,0.101080,0.861291,0.749304
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2019,Szwecja,0.022017,35,0.006089,0.052010,0.019868,0.003605,0.001879,0.010654,0.001542,0.053336,...,0.006096,0.048052,0.019608,0.003419,0.001650,0.010973,0.001605,0.039618,0.011526,0.019703
2020,Szwecja,0.023969,35,0.006179,0.053028,0.019868,0.003904,0.001829,0.010973,0.001605,0.053403,...,0.006175,0.049011,0.019608,0.002887,0.001654,0.007129,0.001515,0.039806,0.011298,0.019193
2021,Szwecja,0.023496,35,0.006258,0.054085,0.019868,0.003296,0.001834,0.007129,0.001515,0.053656,...,0.006232,0.047943,0.019608,0.002706,0.001355,0.005919,0.001227,0.039930,0.010443,0.019574
2022,Szwecja,0.023485,35,0.006315,0.052907,0.019868,0.003090,0.001503,0.005919,0.001227,0.053824,...,0.006279,0.050536,0.019608,0.002678,0.001425,0.000120,0.001618,0.040232,0.011072,0.019473


## Reshape the data

In [9]:
# Turn the column-ordered frames into model inputs and targets.
# num_lags=2 corresponds to the two lag groups (t-2, t-1) prepared above.
# NOTE(review): the returned arrays are assumed to keep the row order of the
# input frames; later cells pair predictions with rows positionally -- confirm.
data_resherper = DataReshaperLSTM()
reshaped_splits = data_resherper.reshape_data(
    train_hybrid_df_sorted,
    test_hybrid_df_sorted,
    num_lags=2,
)
x_train, x_test, y_train, y_test = reshaped_splits

## Build the model

In [10]:
# Per-sample input shape: the last two axes of x_train.
input_shape = x_train.shape[1:3]
output_units = 1

# Two stacked LSTM layers, each followed by dropout, then a single-unit head.
model = Sequential()
# The first LSTM returns the full sequence so the second LSTM can consume it.
model.add(LSTM(48, input_shape=input_shape, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(40))
model.add(Dropout(0.2))
# ReLU on the output layer means predictions are never negative.
model.add(Dense(output_units, activation='relu'))

model.compile(optimizer="adam", loss="mean_absolute_error")
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 lstm (LSTM)                 (None, 2, 48)             11520     
                                                                 
 dropout (Dropout)           (None, 2, 48)             0         
                                                                 
 lstm_1 (LSTM)               (None, 40)                14240     
                                                                 
 dropout_1 (Dropout)         (None, 40)                0         
                                                                 
 dense (Dense)               (None, 1)                 41        
                                                                 
Total params: 25,801
Trainable params: 25,801
Non-trainable params: 0
_________________________________________________________________


## Train the model

In [11]:
# Stop once the monitored loss has not improved for 10 consecutive epochs and
# roll back to the best weights seen. No validation data is passed to fit(),
# so the monitored 'loss' is the training loss.
early_stopping = EarlyStopping(
    monitor='loss',
    patience=10,
    restore_best_weights=True,
)

fit_options = dict(
    epochs=config.epochs,
    batch_size=config.batch_size,
    callbacks=[early_stopping],
)
history = model.fit(x_train, y_train, **fit_options)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100


In [12]:
# Save the trained model
model_path = os.path.join(config.models_folder, f'{variant_co2}_lightgbm_lstm_model.h5')
model.save(model_path)

# Evaluate the model on the test set. The model is compiled with a single loss
# (mean absolute error) and no extra metrics, so `loss` is that test-set loss,
# computed on the values as fed to the network (before any inverse transform).
loss = model.evaluate(x_test, y_test, verbose=0)

# Show the result as the cell output; previously it was computed and then
# silently discarded.
loss

## Predictions

In [13]:
# Model outputs for both splits.
train_predictions = model.predict(x_train)
test_predictions = model.predict(x_test)

# Map predictions and targets back through the untouched preprocessor copy.
# inverse_transform_data receives (values, row count, column count).
# NOTE(review): the "- 4" column offset is not explained anywhere in this
# notebook; presumably it adjusts the hybrid frame's width to the width the
# preprocessor was fitted on -- confirm against inverse_transform_data.
inverse_transform = data_preprocessor_lstm_for_prediction.inverse_transform_data

n_train_rows = train_predictions.shape[0]
n_train_columns = train_hybrid_df.shape[1]-4
inverted_data_predicted_train_y = inverse_transform(train_predictions, n_train_rows, n_train_columns)
inverted_data_train_y = inverse_transform(y_train, n_train_rows, n_train_columns)

n_test_rows = test_predictions.shape[0]
n_test_columns = test_hybrid_df.shape[1]-4
inverted_data_predicted_test_y = inverse_transform(test_predictions, n_test_rows, n_test_columns)
inverted_data_test_y = inverse_transform(y_test, n_test_rows, n_test_columns)



In [14]:
# Move the (year, country) index levels into ordinary columns so they can be
# copied into the result tables.
train_df_reset = train_hybrid_df.reset_index()
test_df_reset = test_hybrid_df.reset_index()

# The country index level was named via config.additional_index when the frames
# were loaded, so read it back through the same setting instead of a hard-coded
# "country_index" (which raises KeyError as soon as the config value differs).
country_column = config.additional_index

# The last column of each inverse-transformed array is taken as the co2 value.
co2_predicted_train = inverted_data_predicted_train_y[:, -1]
co2_actual_train = inverted_data_train_y[:, -1]
co2_predicted_test = inverted_data_predicted_test_y[:, -1]
co2_actual_test = inverted_data_test_y[:, -1]

# Create DataFrames for train and test.
# Rows are paired positionally: prediction i belongs to row i of the hybrid frame.
train_results = pd.DataFrame({
    "country": train_df_reset[country_column].values,
    "year": train_df_reset["year"].values,
    "co2_predicted": co2_predicted_train,
    "co2_actual": co2_actual_train
})

test_results = pd.DataFrame({
    "country": test_df_reset[country_column].values,
    "year": test_df_reset["year"].values,
    "co2_predicted": co2_predicted_test,
    "co2_actual": co2_actual_test
})


# Persist both tables. to_csv keeps its default index=True, so the row index is
# written as an unnamed first column.
train_results.to_csv(os.path.join(config.predictions_hybrid_lightgbm_lstm, f'lightgbm_lstm_{variant_co2}_train.csv'))
test_results.to_csv(os.path.join(config.predictions_hybrid_lightgbm_lstm, f'lightgbm_lstm_{variant_co2}_test.csv'))

In [15]:
current_model = "lightgbm_lstm"

# Shared results file for this CO2 variant; rows are keyed by
# (country, year, set).
global_csv_path = os.path.join(config.predictions, f'combined_results_{variant_co2}.csv')
global_results = GlobalResults(global_csv_path, keys=["country", "year", "set"])

# Suffix both value columns with the model name so they stay distinguishable
# from other models' columns in the combined file.
model_column_names = {
    source_name: f"{source_name}_{current_model}"
    for source_name in ("co2_predicted", "co2_actual")
}

# Train split: work on a copy, make sure 'year' is a column, tag and rename.
train = train_results.copy()
if "year" not in train.columns:
    train = train.reset_index()
train["set"] = "train"
train = train.rename(columns=model_column_names)

# Test split: same steps.
test = test_results.copy()
if "year" not in test.columns:
    test = test.reset_index()
test["set"] = "test"
test = test.rename(columns=model_column_names)

# Stack both splits and order the rows before handing them over.
new_results_df = pd.concat([train, test], axis=0)
new_results_df = new_results_df.sort_values(by=["year", "country", "set"])

global_results.append_results(new_results_df)

## Metrics

In [16]:
# Score the train and test result tables (actual vs. predicted co2) under the
# variant label for this model, passing config.metrics_hybrid as the output
# file argument.
evaluator = PredictionEvaluator()
evaluator.evaluate(
    train_results,
    test_results,
    actual_col='co2_actual',
    predicted_col='co2_predicted',
    variant=f'lightgbm_lstm_{variant_co2}',
    model_output_file=config.metrics_hybrid,
)