In [None]:
import os

# Limit OpenMP to 32 threads. An IPython shell escape (`!export ...`) runs in a
# short-lived subshell and never changes this kernel's environment, so the
# variable is set on os.environ instead. It is set before torch is imported
# because the OpenMP runtime reads OMP_NUM_THREADS when it initializes.
os.environ["OMP_NUM_THREADS"] = "32"

import omegaconf
from omegaconf import DictConfig, OmegaConf

path_dir_notebook_file = ""                                         # IMPORTANT: Define the directory's path
                                                                    # where the notebook file is contained.
# Explicit checks instead of `assert` (asserts are stripped when Python runs with -O).
if not os.path.isdir(path_dir_notebook_file):
  raise FileNotFoundError("Notebook's directory {} does not exist!".format(path_dir_notebook_file))
os.chdir(path_dir_notebook_file)

path_dir_configs = os.path.join(path_dir_notebook_file, "configs")

filename_config_file = "all_NO2_Italian_stations_2020_2020.yaml"    # IMPORTANT: Define the configuration file to run.
path_config_file = os.path.join(path_dir_configs, filename_config_file)
if not os.path.isfile(path_config_file):
  raise FileNotFoundError("Config file {} does not exist!".format(path_config_file))

cfg = omegaconf.OmegaConf.load(path_config_file)

# ------------------ Libraries ------------------

# Set default GPU to use
gpu_default = cfg['gpu_default']

import warnings
warnings.filterwarnings('ignore')

import torch
import pandas as pd
import numpy as np

import sys
import pickle
import pyro
from pyro.infer.autoguide import AutoDelta, AutoDiagonalNormal, AutoLowRankMultivariateNormal
from datetime import datetime, timedelta
from typing import Callable
import graphviz
import pyro.poutine as poutine
import random
import json
from math import ceil

from pyro.infer import SVI, Trace_ELBO, Predictive, TracePredictive, TraceMeanField_ELBO
from tqdm.auto import trange, tqdm
from IPython.display import display
from IPython.display import clear_output
import gc

# Make PyroModule parameters local (like ordinary torch.nn.Parameters),
# rather than shared by name through Pyro's global parameter store.
# This is highly recommended whenever models can be written without pyro.param().
pyro.settings.set(module_local_params=True)

# Seed every RNG in use (Pyro, NumPy, torch CPU/CUDA, Python) for reproducibility.
seed = cfg['seed']
pyro.set_rng_seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)

USE_CUDA = torch.cuda.is_available()
print("Cuda is available: {}".format(USE_CUDA))

if USE_CUDA:
  torch.cuda.set_device(gpu_default)
  print("Default GPU CUDA: {}".format(torch.cuda.current_device()))

device = torch.device("cuda" if USE_CUDA else "cpu")

# Project modules (star imports provide joinpath, Model, dataset classes, helpers, ...).
from dataset import *
from BYADNN import *
from utils import *
from utils_pyro import *

# ------------------ Dataset information ------------------
local_dir_datasets = os.path.join(path_dir_notebook_file, cfg['dataset_info']['path_dir_datasets'])     # Complete path of the directory containing all datasets
if not os.path.isdir(local_dir_datasets):
  raise FileNotFoundError("Datasets's directory {} does not exist!".format(local_dir_datasets))

gp_method = cfg['dataset_info']['gp_method']
air_poll_selected = cfg['dataset_info']['air_poll_selected']
freq_mode = cfg['dataset_info']['freq_mode']
start_year = cfg['dataset_info']['start_year']
end_year = cfg['dataset_info']['end_year']
type_station = cfg['dataset_info']['type_station']
n_stations = cfg['dataset_info']['n_stations']
co_in_ug_m3 = cfg['dataset_info']['co_in_ug_m3']
round_prediction = cfg['dataset_info']['round_prediction']
remove_imp_target_training = cfg['dataset_info']['remove_imp_target_training']
org_data = cfg['dataset_info']['org_data']
standardization_data = cfg['dataset_info']['standardization_data']

dict_limit_air_pollutants = cfg['dict_limit_air_pollutants']

dataset_name = "org_{}_{}_{}_{}_{}_cod_stations_1x".format(air_poll_selected, start_year, end_year, type_station, n_stations)

# ------------------ Model configuration ------------------
dict_model_config = cfg['model']
model_name = cfg['defaults']['model']

# ------------------ Experiment configuration ------------------
dict_exp_config = cfg['experiment']
batch_size = cfg['experiment']['batch_size']

# ------------------ Experiment's status ------------------
resume_training = False
eval_model = False
dataloader_saved = False
plot_histo_dataset = False

# Time span covered by the experiment: [start_year-01-01, (end_year+1)-01-01).
start_date = datetime(start_year, 1, 1, 0, 0)
end_date = datetime(end_year+1, 1, 1, 0, 0)

delta_time = timedelta(hours=1) if freq_mode == "hour" else timedelta(days=1)

# Dataset layout: <datasets>/<GP|D2_K2>/<dataset_name>/<round_prediction>/...
path_dir_dataset = joinpath(local_dir_datasets, "GP" if gp_method else "D2_K2")

path_dir_dataset = joinpath(path_dir_dataset, dataset_name)
path_txt_cod_stations = joinpath(path_dir_dataset, "cod_stations.txt")
path_dir_dataset = joinpath(path_dir_dataset, round_prediction)

path_idx_missing_values_train = joinpath(path_dir_dataset, "all_idx_missing_values_train.pickle")
path_idx_missing_values_complete = joinpath(path_dir_dataset, "all_complete_idx_missing_values.pickle")
path_idx_train = joinpath(path_dir_dataset, "all_idx_train.pickle")
path_idx_test = joinpath(path_dir_dataset, "all_idx_test.pickle")
path_x_train = joinpath(path_dir_dataset, "all_x_train.pickle")
path_y_train = joinpath(path_dir_dataset, "all_y_train.pickle")
path_x_test = joinpath(path_dir_dataset, "all_x_test.pickle")
path_y_test = joinpath(path_dir_dataset, "all_y_test.pickle")
path_indo_cod_station = joinpath(path_dir_dataset, "info_cod_stations.pickle")

# Station codes, one per line (optional file).
list_cod_stations = []

if os.path.exists(path_txt_cod_stations):
  with open(path_txt_cod_stations) as f:
    list_cod_stations = [line.rstrip() for line in f]

In [None]:
# Load every per-station dictionary saved for the selected dataset: station
# info, missing-value indices, train/test indices, and train/test inputs/targets.
(
  info_cod_stations,
  all_idx_missing_values_train,
  all_idx_missing_values_complete,
  all_idx_train,
  all_idx_test,
  all_x_train,
  all_y_train,
  all_x_test,
  all_y_test,
) = load_dictionaries_dataset(
  path_indo_cod_station,
  path_idx_missing_values_train,
  path_idx_missing_values_complete,
  path_idx_train,
  path_idx_test,
  path_x_train,
  path_y_train,
  path_x_test,
  path_y_test,
)

# Date lists for the configured years and frequency (see the project helper
# compute_dates for the exact structure).
dict_list_dates = compute_dates(start_year, end_year, freq_mode)

# Largest model order over all stations; passed to dataset_padding below,
# presumably so every station's inputs share that length.
max_model_order = get_maximum_model_order(info_cod_stations, all_x_train)
print("Max model order: {}".format(max_model_order))

padding_result = dataset_padding(
  info_cod_stations, all_x_train, all_x_test,
  dict_list_dates, all_idx_train, all_idx_test, max_model_order,
)
new_x_train_padding, new_x_test_padding, dict_dates_train, dict_dates_test = padding_result


In [3]:
####################### Dataset loading #######################
# Create the experiment's output tree, then either reload the saved per-station
# datasets and DataLoader or build them from the padded arrays.
path_dir_model_config = joinpath(path_dir_dataset, cfg['experiment']['name_exp'])
os.makedirs(path_dir_model_config, exist_ok=True)

path_dataloaders_dir = joinpath(path_dir_model_config, "dataloaders")
os.makedirs(path_dataloaders_dir, exist_ok=True)

path_dataloader_train = joinpath(path_dataloaders_dir, "train_loader.pth")
path_dict_datasets_train = joinpath(path_dataloaders_dir, "dict_datasets_train.pkl")
path_dict_datasets_test = joinpath(path_dataloaders_dir, "dict_datasets_test.pkl")

path_csv_errors_cod_stations = joinpath(path_dir_model_config, "complete_errors.csv")
path_csv_calibration_errors_cod_stations_train = joinpath(path_dir_model_config, "calibration_errors_train.csv")
path_csv_calibration_errors_cod_stations_test = joinpath(path_dir_model_config, "calibration_errors_test.csv")
path_csv_ence_errors_cod_stations_train = joinpath(path_dir_model_config, "ence_errors_train.csv")
path_csv_ence_errors_cod_stations_test = joinpath(path_dir_model_config, "ence_errors_test.csv")

# Remove calibration/ENCE CSVs left over from a previous run of this experiment
# (complete_errors.csv is not removed here).
for path_stale_csv in (
  path_csv_calibration_errors_cod_stations_train,
  path_csv_calibration_errors_cod_stations_test,
  path_csv_ence_errors_cod_stations_train,
  path_csv_ence_errors_cod_stations_test,
):
  if os.path.exists(path_stale_csv):
    os.remove(path_stale_csv)

path_dir_weights = joinpath(path_dir_model_config, "weights")
os.makedirs(path_dir_weights, exist_ok=True)

path_dir_plots = joinpath(path_dir_model_config, "plots")
os.makedirs(path_dir_plots, exist_ok=True)

path_dir_models = joinpath(path_dir_model_config, "models")
os.makedirs(path_dir_models, exist_ok=True)

path_log_txt = joinpath(path_dir_model_config, "log.txt")

path_cod_stations_processed = joinpath(path_dir_model_config, "cod_stations_processed.txt")
path_csv_ts_cod_stations = joinpath(path_dir_model_config, "ts_cod_stations_imputed.csv")

all_new_x_train = {}
all_new_y_train = {}

# Stations already handled by a previous run; they are skipped when building datasets.
list_cod_stations_processed = []

if os.path.exists(path_cod_stations_processed):
  with open(path_cod_stations_processed, 'r') as f:
    list_cod_stations_processed = [line.rstrip('\n') for line in f]

dict_datasets_train = {}
dict_datasets_test = {}
list_len_datasets = []

if resume_training or eval_model or dataloader_saved:
  # NOTE(review): these files are trusted local artifacts of this notebook (pickle and
  # torch.load execute arbitrary code on load). torch >= 2.6 defaults torch.load to
  # weights_only=True, which rejects a pickled DataLoader — pass weights_only=False there.
  train_loader = torch.load(path_dataloader_train)

  with open(path_dict_datasets_train, 'rb') as fp:
    dict_datasets_train = pickle.load(fp)

  with open(path_dict_datasets_test, 'rb') as fp:
    dict_datasets_test = pickle.load(fp)
else:
  # Build one train/test dataset pair per station.
  for key in info_cod_stations:

    # Clear param store
    pyro.clear_param_store()

    current_cod_station = info_cod_stations[key][0]

    if current_cod_station not in list_cod_stations_processed or eval_model:
      print("Current cod station: {}".format(current_cod_station))

      model_order = info_cod_stations[key][1]
      train_dataset = CustomDataSet_Train(current_cod_station, key, info_cod_stations, batch_size, new_x_train_padding, all_y_train, all_idx_missing_values_train,
                                          model_order, max_model_order, dict_dates_train[key], remove_imp_target_training=remove_imp_target_training,
                                          org_data=org_data, standardization_data=standardization_data
      )

      # The test set is normalized with the training statistics of the same station.
      mean_cod_station = train_dataset.mean
      std_cod_station = train_dataset.std

      x_test = new_x_test_padding[key]
      y_test = all_y_test[key]

      test_dataset = CustomDataSet_Test(current_cod_station, key, info_cod_stations, x_test, y_test, model_order, max_model_order, dict_dates_test[key],
                                        org_data=org_data, standardization_data=standardization_data,
                                        mean=mean_cod_station, std=std_cod_station)

      # Histogram plot of i-th cod station time series
      if plot_histo_dataset:
        path_dir_plots_histo = joinpath(path_dir_plots, current_cod_station)
        path_dir_plots_histo = joinpath(path_dir_plots_histo, "histo")
        os.makedirs(path_dir_plots_histo, exist_ok=True)

        compute_and_plot_histogram(all_y_train[key], dict_exp_config['n_bins'], train_dataset.min_cod_station, train_dataset.max_cod_station, train_dataset.mean, train_dataset.std, org_data, standardization_data, current_cod_station, True, path_dir_plots_histo)
        compute_and_plot_histogram(all_y_test[key], dict_exp_config['n_bins'], test_dataset.min_cod_station, test_dataset.max_cod_station, test_dataset.mean, test_dataset.std, org_data, standardization_data, current_cod_station, False, path_dir_plots_histo)

      # The global batch size shrinks to the smallest station's training-set length.
      if batch_size > len(train_dataset):
        batch_size = len(train_dataset)

      list_len_datasets.append(len(train_dataset))

      dict_datasets_train[key] = train_dataset
      dict_datasets_test[key] = test_dataset

  if not list_len_datasets:
    raise ValueError("ERROR: No training dataset was built (no stations, or all already listed in {})".format(path_cod_stations_processed))

  if batch_size <= 0:
    # Happens when a station has an empty training set (batch_size was shrunk to 0);
    # continuing would divide by zero below.
    raise ValueError("ERROR: Invalid batch size {}: at least one station has an empty training dataset".format(batch_size))

  final_dataset = torch.utils.data.ConcatDataset(dict_datasets_train.values())

  # One list of global indices per station. ConcatDataset places station i at
  # [sum(len_0..len_{i-1}), sum(len_0..len_i)), so the offset must advance by the
  # FULL dataset length. Only the number of indices taken from each station is
  # truncated to a multiple of batch_size, presumably so MyBatchSampler builds
  # single-station batches.
  list_indices = []
  offset = 0
  for len_dataset in list_len_datasets:
    n_batches_current_dataset = len_dataset // batch_size

    if n_batches_current_dataset == 0:
      # exit() does not stop a cell inside an IPython kernel; raise instead.
      raise ValueError("ERROR: The dimension of this dataset is less than the batch size specified")

    len_dataset_batch_size = n_batches_current_dataset * batch_size
    list_indices.append(list(range(offset, offset + len_dataset_batch_size)))
    offset += len_dataset

  batch_sampler = MyBatchSampler(list_indices, batch_size)
  train_loader = DataLoader(final_dataset, batch_sampler=batch_sampler, pin_memory=True)

  torch.save(train_loader, path_dataloader_train)

  with open(path_dict_datasets_train, "wb") as fp:
    pickle.dump(dict_datasets_train, fp)

  with open(path_dict_datasets_test, "wb") as fp:
    pickle.dump(dict_datasets_test, fp)

In [None]:
##################### Training and Inference module #####################

# A fresh run truncates the log and writes a short header; a resumed run appends.
if resume_training:
  file_log = open(path_log_txt, "a")
else:
  file_log = open(path_log_txt, "w")
  header_lines = [
    "Len training set: {}\n".format(len(train_loader.dataset)),
    "Input features: {}\n".format(dict_model_config["input_features"]),
    "Max model order: {}\n".format(max_model_order),
  ]
  file_log.writelines(header_lines)
  file_log.flush()

# Same path as computed in the dataset-loading cell.
path_log_txt = joinpath(path_dir_model_config, "log.txt")

# Create a model and a guide, both as (Pyro)Modules.
model: torch.nn.Module = Model( len_dataset=len(train_loader.dataset), input_features=dict_model_config["input_features"], seq_len=max_model_order,
                                embedding_dim=dict_model_config["embedding_dim"], num_enc_block=dict_model_config["num_enc_block"], num_heads=dict_model_config["num_heads"],
                                h_enc_layer=dict_model_config["h_enc_layer"], dropout=dict_model_config["dropout"], h1=dict_model_config["h1"], h2=dict_model_config["h2"],
                                add_positional_encoding=dict_model_config["add_positional_encoding"],
                                add_temporal_embedding=dict_model_config["add_temporal_embedding"],
                                sigma_layer=dict_model_config["sigma_layer"], device=device).to(device)
model.train()

# Variational posterior: a Pyro autoguide selected by name from the experiment config.
guide_classes = {
  "AutoDelta": AutoDelta,
  "AutoDiagonalNormal": AutoDiagonalNormal,
  "AutoLowRankMultivariateNormal": AutoLowRankMultivariateNormal,
}
guide_name = dict_exp_config["guide"]
if guide_name not in guide_classes:
  # Fail fast. Otherwise `guide` would be undefined (or stale from an earlier cell
  # execution) and the error would only surface later, when building the ELBO.
  raise ValueError("Unsupported guide '{}'; expected one of {}".format(guide_name, sorted(guide_classes)))
guide: torch.nn.Module = guide_classes[guide_name](model)

# Text files tracking per-epoch losses and the best checkpoint's metrics.
path_loss_train_txt = joinpath(path_dir_models, "loss_train.txt")
path_loss_test_txt = joinpath(path_dir_models, "loss_test.txt")
path_best_model_txt = joinpath(path_dir_models, "best_model.txt")

if not resume_training:
  # A fresh run starts new loss logs and records the architecture plus the
  # model and experiment configurations as JSON.
  path_architecture_txt = joinpath(path_dir_models, "architecture.txt")
  path_config_model_txt = joinpath(path_dir_models, "config_model.txt")
  path_config_exp_txt = joinpath(path_dir_models, "config_exp.txt")

  file_loss_train = open(path_loss_train_txt, "w")
  file_loss_test = open(path_loss_test_txt, "w")

  # Each payload is rendered only after its file has been opened.
  run_description = (
    (path_architecture_txt, lambda: str(model) + "\n"),
    (path_config_model_txt, lambda: json.dumps(OmegaConf.to_container(dict_model_config))),
    (path_config_exp_txt, lambda: json.dumps(OmegaConf.to_container(dict_exp_config))),
  )
  for path_txt, render_text in run_description:
    with open(path_txt, 'w') as f:
      f.write(render_text())
  
# Create a loss function as a Module that includes model and guide parameters.
# All Pyro ELBO estimators can be __call__()ed with a model and guide pair as arguments
# to return a loss function Module that takes the same arguments as the model and guide
# and exposes all of their torch.nn.Parameters and pyro.nn.PyroParam parameters.
elbo: Callable[[torch.nn.Module, torch.nn.Module], torch.nn.Module] = Trace_ELBO(vectorize_particles=True)
loss: torch.nn.Module = elbo(model, guide)
loss.to(device=torch.device(device))

# All relevant parameters need to be initialized before an optimizer can be created.
# Autoguides create their parameters lazily on first use, so one mini-batch is run
# through the loss to initialize both model and guide parameters.
batch = next(iter(train_loader))
x_batch = batch['x']
y_batch = batch['y']
# Only the first sample's mask is used; presumably all samples in a batch share it — TODO confirm.
mask = batch['mask'][0]
date_batch = batch['date']

if USE_CUDA:
  x_batch = x_batch.cuda()
  y_batch = y_batch.cuda()
  mask = mask.cuda()
  # The model consumes the first five date components.
  date_batch = [date_batch[idx].cuda() for idx in range(5)]

loss(x_batch, date_batch, mask, y_batch)

# Create a PyTorch optimizer for the parameters of the model and guide in loss_fn.
optimizer_name = dict_exp_config["optimizer"]
if optimizer_name == "Adam":
  optimizer = torch.optim.Adam(loss.parameters(), lr=dict_exp_config["lr"], weight_decay=dict_exp_config["weight_decay"],
                               betas=dict_exp_config["betas"])
else:
  # Fail fast rather than leaving `optimizer` undefined (or stale from a previous run).
  raise ValueError("Unsupported optimizer '{}'; only 'Adam' is implemented".format(optimizer_name))

# --------------------- Training phase ---------------------
best_mse_test = np.inf
best_mse_test_org = np.inf
best_elbo_test = np.inf
best_epoch = -1
best_model_state_dict = None
path_last_model = joinpath(path_dir_weights, "last_model.pt")
path_last_model_params = joinpath(path_dir_weights, "last_model_params.pt")
path_best_model = joinpath(path_dir_weights, "best_model.pt")
path_best_model_params = joinpath(path_dir_weights, "best_model_params.pt")

tot_train_elbo = []
test_elbo = []

if resume_training:
  # Restore the loss histories and best-model metrics written by the previous run.
  # `with` guarantees every handle is closed (best_model.txt previously leaked).
  with open(path_loss_train_txt, "r") as file_loss_train:
    lines_file_loss_train = file_loss_train.readlines()
  with open(path_loss_test_txt, "r") as file_loss_test:
    lines_file_loss_test = file_loss_test.readlines()
  with open(path_best_model_txt, "r") as file_best_model:
    lines_file_best_model = file_best_model.readlines()

  tot_train_elbo.extend(float(line) for line in lines_file_loss_train)
  test_elbo.extend(float(line) for line in lines_file_loss_test)

  # Reopen the loss logs in append mode; they stay open for the training loop.
  file_loss_train = open(path_loss_train_txt, "a")
  file_loss_test = open(path_loss_test_txt, "a")

  # best_model.txt lines: epoch, MSE, [MSE in original scale when org_data is False], ELBO.
  best_epoch = int(lines_file_best_model[0])
  best_mse_test = float(lines_file_best_model[1])
  if org_data == False:
    best_mse_test_org = float(lines_file_best_model[2])
    best_elbo_test = float(lines_file_best_model[3])
  else:
    best_elbo_test = float(lines_file_best_model[2])

# Clear param store
pyro.clear_param_store()


def build_lr_scheduler(optimizer, exp_config, last_epoch=-1):
  """Create the learning-rate scheduler named by ``exp_config["lr_scheduler"]``.

  Args:
    optimizer: optimizer whose learning rate is scheduled.
    exp_config: experiment configuration providing ``lr_scheduler`` plus the
      scheduler-specific keys (``lr``, ``steps_per_epoch``, ``num_epochs``,
      ``milestones``, ``warmup``).
    last_epoch: ``last_epoch`` forwarded to the scheduler (-1 for a fresh start).

  Returns:
    The scheduler, or None when ``lr_scheduler`` is None.

  Raises:
    ValueError: if ``lr_scheduler`` names an unsupported scheduler.
  """
  name = exp_config["lr_scheduler"]
  if name is None:
    return None
  if name == "OneCycleLR":
    return torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=exp_config["lr"], steps_per_epoch=exp_config["steps_per_epoch"], epochs=exp_config["num_epochs"], last_epoch=last_epoch)
  if name == "MultiStepLR":
    return torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=exp_config["milestones"], gamma=0.1, last_epoch=last_epoch)
  if name == "CosineAnnealingLR":
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=exp_config["steps_per_epoch"], eta_min=0, last_epoch=last_epoch)
  if name == "CosineWarmupScheduler":
    return CosineWarmupScheduler(optimizer=optimizer, warmup=exp_config["warmup"], max_iters=exp_config["num_epochs"], last_epoch=last_epoch)
  # Fail fast instead of leaving `scheduler` undefined (or stale from an earlier cell run).
  raise ValueError("Unsupported lr_scheduler '{}'".format(name))


if resume_training:
  model_state_dict, guide, optimizer_state_dict, scheduler_state_dict, epoch = load_checkpoint(path_last_model)
elif eval_model:
  model_state_dict, guide, optimizer_state_dict, scheduler_state_dict, best_epoch = load_checkpoint(path_best_model)

last_epoch_scheduler = epoch if resume_training else -1

if not resume_training:
  scheduler = build_lr_scheduler(optimizer, dict_exp_config, last_epoch_scheduler)

if resume_training or eval_model:

  model.load_state_dict(model_state_dict)
  optimizer.load_state_dict(optimizer_state_dict)

  # NOTE(review): `guide` was just replaced by the checkpointed object, while `optimizer`
  # was built from the previous guide's parameters. Unless load_checkpoint/param-store
  # loading rebinds them, resumed training may update stale guide parameters — confirm.
  loss: torch.nn.Module = elbo(model, guide)
  loss.to(device=torch.device(device))

  # Rebuilt after the optimizer state is loaded, then restored from the checkpoint.
  scheduler = build_lr_scheduler(optimizer, dict_exp_config, last_epoch_scheduler)
  if scheduler is not None:
    scheduler.load_state_dict(scheduler_state_dict)

  if resume_training:
    pyro.get_param_store().load(path_last_model_params)
  else:
    pyro.get_param_store().load(path_best_model_params)

  model.to(device)

  if eval_model == False:
    # NOTE(review): training restarts at the checkpointed `epoch`. If save_checkpoint stores
    # the last *completed* epoch, that epoch is re-run — confirm against save_checkpoint.
    bar = trange(epoch, dict_exp_config["num_epochs"])
else:
  bar = trange(dict_exp_config["num_epochs"])

if eval_model == False:

  model.train()
  for epoch in bar:

    file_log.write("Epoch: {}\n".format(epoch))
    file_log.flush()

    # One optimization pass over the training DataLoader (project helper `train`).
    total_epoch_loss_train = train(loss, train_loader, optimizer, scheduler, use_cuda=USE_CUDA)
    tot_train_elbo.append(total_epoch_loss_train)
    file_loss_train.write(str(total_epoch_loss_train) + "\n")
    file_loss_train.flush()

    if epoch == 0:
      tot_number_of_paramaters = total_number_of_params(loss.parameters())
      print("Number of model parameters: {}".format(tot_number_of_paramaters))
      # Path rebuilt here: path_architecture_txt is only defined on fresh runs, but a
      # run resumed from an epoch-0 checkpoint also reaches this branch.
      with open(joinpath(path_dir_models, "architecture.txt"), 'a') as f:
        f.write("Number of model parameters: {}".format(tot_number_of_paramaters))

    bar.set_postfix(tot_loss=f'{total_epoch_loss_train:.3f}')

    if epoch % dict_exp_config["eval_epoch"] == 0 and epoch > 0:

      print("\n------------- Epoch {} Validation -------------".format(epoch))
      file_log.write("\n------------- Epoch {} Validation -------------\n".format(epoch))
      file_log.flush()

      tot_mse_loss_test = 0.0
      tot_mse_loss_test_org = 0.0
      total_epoch_loss_test = 0.0

      # Validate only the stations whose test dataset was actually built (stations
      # skipped as already processed have no entry and would raise KeyError).
      n_eval_stations = len(dict_datasets_test)

      # NOTE(review): the model is still in train() mode here, so dropout stays active
      # during validation (the prediction phase calls model.eval()) — confirm intended.

      # Model and guide are the same for every station, so one Predictive suffices.
      predictive = Predictive(model, guide=guide, num_samples=dict_exp_config["num_samples_pred"])

      for key, test_dataset_station in dict_datasets_test.items():

        cod_station = test_dataset_station.cod_station
        x_test_tensor = test_dataset_station.x_tensor
        y_test_tensor = test_dataset_station.y_tensor
        mask = torch.from_numpy(test_dataset_station.mask).to(device)
        dates = test_dataset_station.get_list_info_date()

        min_cod_station = info_cod_stations[key][5]
        max_cod_station = info_cod_stations[key][6]

        mean_cod_station = test_dataset_station.mean
        std_cod_station = test_dataset_station.std

        epoch_loss_test = evaluate(loss, x_test_tensor, y_test_tensor, dates, mask, use_cuda=USE_CUDA)
        total_epoch_loss_test += epoch_loss_test

        preds_test = predictive(x=x_test_tensor.to(device), date=dates, mask=mask, y=None)
        # Average the predictive samples per timestamp (assumes 'obs' is (num_samples, N)).
        # .cpu() is a no-op for CPU tensors, so one expression covers both devices.
        y_pred_test = preds_test['obs'].T.detach().cpu().numpy().mean(axis=1)

        if standardization_data:
          if std_cod_station > 0.0:
            y_test_tensor = (y_test_tensor * std_cod_station) + mean_cod_station
            y_pred_test = (y_pred_test * std_cod_station) + mean_cod_station
        else:
          # Add the station mean back to BOTH target and prediction. Assigning the
          # shifted prediction to x_test_tensor would leave y_pred_test unshifted
          # and bias the MSE by the station mean.
          y_test_tensor = y_test_tensor + mean_cod_station
          y_pred_test = y_pred_test + mean_cod_station

        mse_loss_test = mean_squared_error(y_test_tensor, y_pred_test)

        if org_data == False:
          # Also report the error after undoing the min-max scaling.
          y_test_org = inverse_transform_minmax(y_test_tensor, min_cod_station, max_cod_station)
          y_test_pred_org = inverse_transform_minmax(y_pred_test, min_cod_station, max_cod_station)
          mse_loss_test_org = mean_squared_error(y_test_org, y_test_pred_org)

          print("Current cod station: {} - MSE test: {} - MSE test org: {}".format(cod_station, mse_loss_test, mse_loss_test_org))
          file_log.write("Current cod station: {} - MSE test: {} - MSE test org: {}\n".format(cod_station, mse_loss_test, mse_loss_test_org))
          file_log.flush()

          tot_mse_loss_test_org += mse_loss_test_org
        else:
          print("Current cod station: {} - MSE test: {}".format(cod_station, mse_loss_test))
          file_log.write("Current cod station: {} - MSE test: {}\n".format(cod_station, mse_loss_test))
          file_log.flush()

        tot_mse_loss_test += mse_loss_test

      # Average the metrics over the validated stations.
      tot_mse_loss_test /= n_eval_stations
      total_epoch_loss_test /= n_eval_stations

      if org_data == False:
        tot_mse_loss_test_org /= n_eval_stations

      test_elbo.append(total_epoch_loss_test)
      file_loss_test.write(str(total_epoch_loss_test) + "\n")
      file_loss_test.flush()

      # Model selection is driven by the average test MSE.
      if best_mse_test > tot_mse_loss_test:
        best_epoch = epoch
        best_mse_test = tot_mse_loss_test

        if org_data == False:
          best_mse_test_org = tot_mse_loss_test_org

        best_elbo_test = total_epoch_loss_test
        best_model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}

        print("\nBest MSE test: {} - epoch: {}".format(best_mse_test, best_epoch))
        if org_data == False:
          print("Best MSE test org: {} - epoch: {}".format(best_mse_test_org, best_epoch))
        print("ELBO loss test: {} - epoch: {}".format(best_elbo_test, best_epoch))

        file_log.write("\nBest MSE test: {} - epoch: {}\n".format(best_mse_test, best_epoch))
        if org_data == False:
          file_log.write("Best MSE test org: {} - epoch: {}\n".format(best_mse_test_org, best_epoch))
        file_log.write("ELBO loss test: {} - epoch: {}\n".format(best_elbo_test, best_epoch))
        file_log.flush()

        # Mode "w" truncates any previous content, and `with` closes the file.
        # Line layout is the one parsed when resuming training.
        with open(path_best_model_txt, "w") as file_best_model:
          file_best_model.write(str(best_epoch) + "\n")
          file_best_model.write(str(best_mse_test) + "\n")
          if org_data == False:
            file_best_model.write(str(best_mse_test_org) + "\n")
          file_best_model.write(str(best_elbo_test) + "\n")

        # Save best model param
        save_checkpoint(best_model_state_dict, guide, optimizer, scheduler, epoch, path_best_model, path_best_model_params)

      else:
        print("\nCurrent MSE test: {} - Best MSE: {} - epoch: {}".format(tot_mse_loss_test, best_mse_test, best_epoch))
        if org_data == False:
          print("Current MSE test org: {} - Best MSE org: {} - epoch: {}".format(tot_mse_loss_test_org, best_mse_test_org, best_epoch))
        print("Current loss test: {} - Best loss test: {} - epoch: {}".format(total_epoch_loss_test, best_elbo_test, best_epoch))

        file_log.write("\nCurrent MSE test: {} - Best MSE: {} - epoch: {}\n".format(tot_mse_loss_test, best_mse_test, best_epoch))
        if org_data == False:
          file_log.write("Current MSE test org: {} - Best MSE org: {} - epoch: {}\n".format(tot_mse_loss_test_org, best_mse_test_org, best_epoch))
        file_log.write("Current loss test: {} - Best loss test: {} - epoch: {}\n".format(total_epoch_loss_test, best_elbo_test, best_epoch))
        file_log.flush()

    # Save last model param
    model_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
    save_checkpoint(model_state_dict, guide, optimizer, scheduler, epoch, path_last_model, path_last_model_params)

# --------------------- Prediction phase ---------------------
# Drop the training-time objects; the best checkpoint is reloaded below.
pyro.clear_param_store()
del model, guide

print("\n------------- Prediction phase -------------")
file_log.write("\n------------- Prediction phase -------------\n")
file_log.flush()

# Rebuild the architecture with the same hyper-parameters used for training.
model_hparams = dict(
  len_dataset=len(train_loader.dataset),
  input_features=dict_model_config["input_features"],
  seq_len=max_model_order,
  embedding_dim=dict_model_config["embedding_dim"],
  num_enc_block=dict_model_config["num_enc_block"],
  num_heads=dict_model_config["num_heads"],
  h_enc_layer=dict_model_config["h_enc_layer"],
  dropout=dict_model_config["dropout"],
  h1=dict_model_config["h1"],
  h2=dict_model_config["h2"],
  add_positional_encoding=dict_model_config["add_positional_encoding"],
  add_temporal_embedding=dict_model_config["add_temporal_embedding"],
  sigma_layer=dict_model_config["sigma_layer"],
  device=device,
)
model: torch.nn.Module = Model(**model_hparams).to(device)

# Restore the weights, and the guide object, from the best checkpoint.
model_state_dict, guide, _, _, epoch = load_checkpoint(path_best_model)
model.load_state_dict(model_state_dict)

pyro.get_param_store().load(path_best_model_params)
model.to(device)

# NOTE(review): this reports best_epoch - 1, whereas the validation logs report
# best_epoch itself — confirm which convention is intended.
reported_best_epoch = best_epoch - 1
print("Best epoch: {}\n".format(reported_best_epoch))
file_log.write("Best epoch: {}\n\n".format(reported_best_epoch))
file_log.flush()

model.eval()

# Predictive distribution
# num_samples_pred: Number of realizations for each timestamp t
predictive = Predictive(model, guide=guide, num_samples=dict_exp_config["num_samples_pred"])

for key in info_cod_stations.keys():
  
  train_dataset = dict_datasets_train[key]

  current_cod_station = dict_datasets_test[key].cod_station
  x_test = dict_datasets_test[key].x_tensor
  y_test = dict_datasets_test[key].y_tensor
  dates_test = dict_datasets_test[key].get_list_info_date()
  
  region = info_cod_stations[key][2]
  start_year = info_cod_stations[key][3]
  end_year = info_cod_stations[key][4]

  min_cod_station = info_cod_stations[key][5]
  max_cod_station = info_cod_stations[key][6]

  # ---- Per-station setup: time window, dataset views, output directories ----
  # Station window: from 1 Jan start_year 00:00 up to 1 Jan (end_year + 1) 00:00,
  # i.e. whole calendar years. Passed to plot_ts further below.
  start_date_current_cod_station = datetime(start_year, 1, 1, 0, 0)
  end_date_current_cod_station = datetime(end_year+1, 1, 1, 0, 0)

  # Statistics exposed by the training dataset. They are forwarded together with
  # `standardization_data` to prediction_phase and plot_ts_segments, presumably
  # to undo standardization of the series -- TODO confirm.
  mean_cod_station = train_dataset.mean
  std_cod_station = train_dataset.std
  
  # All plots for this station are written under <path_dir_plots>/<station code>.
  path_dir_plots_current_station = joinpath(path_dir_plots, current_cod_station)
  os.makedirs(path_dir_plots_current_station, exist_ok=True)

  # NOTE(review): x_train_to_use, y_train_to_use and x_train_tensor are not
  # referenced in the rest of this visible section; verify whether still needed.
  x_train_to_use = train_dataset.x_train_to_use
  y_train_to_use = train_dataset.y_train_to_use
  # x_train is rebound below from prediction_phase's return value.
  x_train = train_dataset.x_train
  y_train = train_dataset.y_train
  x_train_tensor = train_dataset.x_train_tensor

  # Presumably the autoregressive order (number of lagged inputs per sample)
  # -- confirm. It is also added to the series length N further below.
  model_order = train_dataset.model_order

  # Presumably the missing-value indices over the station's complete series
  # (per the dict name); consumed by the plotting calls below.
  idx_complete_ts = all_idx_missing_values_complete[key]

  # train_dataset.mask must be a NumPy array (torch.from_numpy requires one);
  # convert it to a tensor on the target device.
  mask = train_dataset.mask
  mask = torch.from_numpy(mask).to(device)

  # Training-sample dates (for prediction_phase) and indices of the missing
  # training values (consumed by imputation below).
  dates_train = train_dataset.get_list_info_date()
  idx_missing_values = train_dataset.idx_missing_values

  # Sub-folder passed to prediction_phase, presumably for its calibration plots.
  path_dir_plots_current_station_calibr_plot = joinpath(path_dir_plots_current_station, "Calibration Plot")
  os.makedirs(path_dir_plots_current_station_calibr_plot, exist_ok=True)
  
  # Run the predictive on the train and test inputs. Returns, for both splits,
  # mean and median predictions and confidence intervals, plus x_train/x_test,
  # which replace the current bindings. The CSV paths suggest it also records
  # per-station error, calibration and ENCE metrics -- confirm in prediction_phase.
  y_pred_mean_train, y_pred_median_train, ci_train, \
  y_pred_mean_test, y_pred_median_test, ci_test, \
  x_train, x_test = \
    prediction_phase( current_cod_station, region, start_year, end_year, predictive, 
                      x_train, y_train, y_test, x_test,
                      min_cod_station, max_cod_station, 
                      mean_cod_station, std_cod_station,
                      mask, dates_train, dates_test,
                      path_csv_errors_cod_stations, device, USE_CUDA,
                      org_data, standardization_data,
                      dict_exp_config["num_samples_pred"],
                      dict_exp_config["delta_conf_inter"],
                      dict_exp_config["confidence_level"],
                      path_dir_plots_current_station_calibr_plot,
                      path_csv_calibration_errors_cod_stations_train,
                      path_csv_calibration_errors_cod_stations_test,
                      path_csv_ence_errors_cod_stations_train,
                      path_csv_ence_errors_cod_stations_test,
                      model_name
              )
  
  # Choose which point estimate returned by prediction_phase (mean or median)
  # drives imputation and plotting. Every accepted value must bind both
  # y_pred_train and y_pred_test. Otherwise the code below would raise a
  # NameError, or silently reuse predictions bound earlier (e.g. for a
  # previous station).
  if dict_exp_config["type_prediction"] == "mean":
    y_pred_train = y_pred_mean_train
    y_pred_test = y_pred_mean_test
  elif dict_exp_config["type_prediction"] == "median":
    y_pred_train = y_pred_median_train
    y_pred_test = y_pred_median_test
  else:
    raise ValueError(
      "Unsupported type_prediction {!r}: expected 'mean' or 'median'".format(
        dict_exp_config["type_prediction"]
      )
    )

  # --------------------- Imputation phase ---------------------
  # Presumably fills the training-set gaps at idx_missing_values using
  # y_pred_train. The return value is ignored and all_new_x_train[key] /
  # all_new_y_train[key] are read right below, so results are expected to land
  # in those dicts under `key` -- confirm in imputation().
  imputation(
                x_train, y_train, y_pred_train, idx_missing_values,
                model_order, max_model_order, key, all_new_x_train, all_new_y_train,
            )
    
  # --------------------- Reconstruction complete time series ---------------------
  # Length of the reconstructed series: #train samples + #test samples +
  # model_order. The extra model_order presumably accounts for the initial lag
  # window that has no sample of its own -- confirm.
  N = all_idx_train[key].shape[0] + all_idx_test[key].shape[0] + model_order

  # Rebuild the full series for this station. Per the names: "old" presumably
  # the original series and "new" the imputed one, with its confidence interval.
  old_complete_ts_current_cod_station, new_complete_ts_current_cod_station, new_complete_ts_current_cod_station_ci = \
    reconstruct_complete_ts(  x_train, y_train, ci_train, x_test, y_test, ci_test,
                              all_new_x_train[key], all_new_y_train[key], all_idx_train[key], all_idx_test[key], 
                              min_cod_station, max_cod_station, N, model_order, max_model_order,
                              org_data,
                          )
  
  # --------------------- Plot complete time series ---------------------
  # Every figure for this station (full series and train/test segments) is
  # written under "<station plot dir>/Time Series".
  path_dir_plots_current_station_ts = joinpath(path_dir_plots_current_station, "Time Series")
  os.makedirs(path_dir_plots_current_station_ts, exist_ok=True)

  plot_ts(current_cod_station,
          old_complete_ts_current_cod_station,
          new_complete_ts_current_cod_station,
          new_complete_ts_current_cod_station_ci,
          idx_complete_ts, path_dir_plots_current_station_ts,
          start_date, end_date,
          start_date_current_cod_station, end_date_current_cod_station,
          dict_limit_air_pollutants, air_poll_selected, delta_time, co_in_ug_m3,
          dict_exp_config["confidence_level"],
          model_name)

  # Segment plots: the training split first, then the test split. The two
  # calls differ only in the (inputs, targets, predictions, intervals) group
  # and in the `train` flag.
  for _seg_x, _seg_y, _seg_pred, _seg_ci, _seg_is_train in (
      (x_train, y_train, y_pred_train, ci_train, True),
      (x_test, y_test, y_pred_test, ci_test, False),
  ):
    plot_ts_segments(current_cod_station,
                     _seg_x, _seg_y, _seg_pred, _seg_ci,
                     idx_complete_ts, start_date,
                     min_cod_station, max_cod_station,
                     model_order, air_poll_selected,
                     co_in_ug_m3, path_dir_plots_current_station_ts,
                     dict_exp_config["n_segments"],
                     dict_exp_config["confidence_level"],
                     # NOTE(review): min/max are passed a second time here,
                     # exactly as before -- confirm against the signature.
                     min_cod_station, max_cod_station,
                     mean_cod_station, std_cod_station,
                     standardization_data, org_data,
                     train=_seg_is_train,
                     model_name=model_name)

  # Outside evaluation mode, append this station's code to the "processed"
  # list file. `not eval_model` is the exact complement of the `if eval_model:`
  # check that runs after the loop, so every value of eval_model takes exactly
  # one of the two paths.
  if not eval_model:
    with open(path_cod_stations_processed, 'a', encoding="utf-8") as f:
      f.write(current_cod_station + "\n")

  # TODO(review): disabled CSV export of the reconstructed series. It refers to
  # `new_complete_ts_current_cod_station_std`, which is not defined in the
  # visible part of this loop (reconstruct_complete_ts yields
  # `new_complete_ts_current_cod_station_ci`). Update it before re-enabling.
  # else:
  #   save_csv_ts(current_cod_station, new_complete_ts_current_cod_station,
  #               new_complete_ts_current_cod_station_std, N, path_csv_ts_cod_stations,
  #               start_date, delta_time, air_poll_selected)

# ------------------ Loss curves ------------------
# The training and test loss histories are plotted into "<plots dir>/Losses".
path_dir_plots_losses = joinpath(path_dir_plots, "Losses")
os.makedirs(path_dir_plots_losses, exist_ok=True)

path_tot_loss_training, path_tot_loss_test = (
  joinpath(path_dir_plots_losses, png_name)
  for png_name in ("tot_loss_train.png", "tot_loss_test.png")
)

plot_loss(tot_train_elbo, "Tot loss training set", path_tot_loss_training)
plot_loss(test_elbo, "Tot loss test set", path_tot_loss_test)

if eval_model:
  # Evaluation mode only: hand the per-station dictionaries (train/test
  # indices, missing-value indices, imputed training data, test data and
  # station info) to save_pickle_datasets. Output goes to a "pickles"
  # sub-folder of the model-config directory.
  path_dir_results_datasets = joinpath(path_dir_model_config, "pickles")
  os.makedirs(path_dir_results_datasets, exist_ok=True)

  save_pickle_datasets( 
                        all_idx_train, all_idx_test, all_idx_missing_values_train,
                        all_idx_missing_values_complete, all_new_x_train, all_new_y_train, 
                        all_x_test, all_y_test,
                        info_cod_stations, path_dir_results_datasets
                      )
    
# Drop Pyro's global parameter store and the trained model/guide objects,
# then clear this notebook cell's output.
pyro.clear_param_store()
del model, guide

clear_output()