# **Start**

In [None]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.10.0+cu111.html
!pip install torch-geometric
!pip install torch-geometric-temporal

In [None]:
!pip install pytorch_lightning
!pip install -qqq wandb

In [None]:
!pip install torch==1.10.0+cu111 torchvision==0.11.0+cu111 torchaudio==0.10.0 -f https://download.pytorch.org/whl/torch_stable.html


In [None]:
""" 

Substitute these two lines in dataloader.py and delete the two corresponding imports from torch._six (torch._six no longer exists in recent PyTorch releases):

import collections.abc as container_abcs
int_classes = int 

"""

In [2]:
import torch
import collections.abc as container_abcs
from torch.nn import functional as F
import pytorch_lightning as pl
from torch_geometric_temporal.nn.recurrent import GConvLSTM


In [3]:

from torch_geometric_temporal.signal import temporal_signal_split
from torch_geometric_temporal.signal import StaticGraphTemporalSignal


from pytorch_lightning.loggers import WandbLogger
import wandb
import matplotlib.pyplot as plt
import PIL
import numpy as np
from time import strftime
import torch.nn as nn
import io
import json
from sklearn.metrics import mean_absolute_error as MAE




In [5]:
# Colab only: mount Google Drive under /content/drive (presumably where the dataset JSON lives — confirm)
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# **Utils**

In [6]:
import torch
import random
import wandb
import io
import PIL
import torch.autograd as autograd
from torch_geometric_temporal.signal import StaticGraphTemporalSignal
from pylab import rcParams
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from palettable.cartocolors.sequential import SunsetDark_6


def getVariablesClass(inst):
    """
    List the names of the non-callable attributes declared directly on the class of ``inst``.

    Used by get_hyperparams_dict to collect hyper-parameter names from a parameter class.

    Inputs:
        inst : instance whose class attributes should be listed
    Outputs:
        list of attribute names, in class-definition order

    Python-generated dunder names (``__module__``, ``__doc__``, ``__dict__``,
    ``__weakref__``, ...) are skipped. They are never user parameters, and values such
    as the ``__dict__`` mappingproxy are not serialisable by loggers.
    """
    cls = inst.__class__
    return [
        name
        for name in cls.__dict__
        # dunders are class plumbing, not user-declared attributes
        if not (name.startswith('__') and name.endswith('__'))
        and not callable(getattr(cls, name))
    ]


def get_hyperparams_dict(p):
    """
    Build a ``{name: value}`` dictionary of the parameters held by ``p``.

    Names come from ``getVariablesClass(p)``. Each value is read from ``p`` itself,
    so an instance-level override of a class attribute is what gets reported.
    """
    return {name: getattr(p, name) for name in getVariablesClass(p)}


def degrade_dataset(X, missingness, rand, v):
    """
    Corrupt a random fraction of the entries of a tensor.

    Inputs:
        X : tensor to corrupt (left untouched; a clone is corrupted)
        missingness : fraction of entries to overwrite, in [0, 1]
        rand : unused; the entries are drawn with the module-level ``random`` generator,
               so seed ``random`` for reproducibility
        v : numeric fill value written into the corrupted entries
            (e.g. 0, -1 or float('nan')); strings are not accepted
    Outputs:
        cX : corrupted copy of X, same shape
        mask : uint8 tensor with the shape of X, 1 = corrupted, 0 = original
    """
    x_flat = X.clone().flatten()
    n = len(x_flat)
    # Same draw as before, so results under a given random.seed are unchanged
    corrupt_ids = random.sample(range(n), int(missingness * n))
    idx = torch.tensor(corrupt_ids, dtype=torch.long)

    # Vectorised writes instead of a per-element Python loop
    x_flat[idx.to(x_flat.device)] = v
    mask_flat = torch.zeros(n)
    mask_flat[idx] = 1

    cX = x_flat.reshape(X.shape)
    mask = mask_flat.reshape(X.shape).byte()
    return cX, mask


def check_mask(x, device, v=-1):
    """
    Build the corruption mask of a batch: where ``x`` holds the placeholder value ``v``.

    Inputs:
        x : data batch corrupted (any shape)
        device : device on which the mask is returned (x is moved there first)
        v : placeholder marking a corrupted entry (default -1)
    Outputs:
        mask : tensor with the shape of x, default float dtype,
               1 = corrupted, 0 = original (the paper uses the opposite convention)
    """
    # `.to` is not in-place: keep the moved tensor (the result used to be discarded)
    x = x.to(device)
    return (x == v).to(torch.get_default_dtype())


def get_only_day_data(y, h, device):
    """
    Zero the predictions wherever the target is exactly 0.0 (night).

    Inputs:
        y : data to predict (targets); entries equal to 0.0 count as night
        h : predictions, same shape as y (or broadcastable to it)
        device : device on which the result is computed
    Outputs:
        h multiplied by the day mask (1 = day, 0 = night), so night entries become 0
    """
    # `.to` returns a new tensor; the results used to be discarded
    y = y.to(device)
    h = h.to(device)
    day_mask = (y != 0.0).to(torch.get_default_dtype())
    return day_mask * h


def buffer_plot_and_get(fig):
    """
    Util function for visualization output.

    Renders ``fig`` into an in-memory buffer (default savefig format) and returns
    that buffer opened as a PIL image. The figure itself is left open.
    """
    image_bytes = io.BytesIO()
    fig.savefig(image_bytes)
    # savefig leaves the cursor at the end; rewind so PIL reads from the start
    image_bytes.seek(0)
    pil_image = PIL.Image.open(image_bytes)
    return pil_image


def visualization_output_imputation(features, targets, predictions, predictions_rec, params, station_name):
    """
    Plot one station's input, target, raw prediction and reconstructed prediction (imputation task).

    Inputs:
        features : corrupted input sequence of the station (length params.LAGS)
        targets : ground-truth sequence
        predictions : raw model output
        predictions_rec : reconstructed sequence (as built by load_prediction_data)
        params : needs LAGS
        station_name : used in the title and legend labels
    Outputs:
        (pil_image, plt) : the rendered figure as a PIL image, and the pyplot module
                           (whose current figure is the one drawn here)
    """
    x_axis_feature = np.arange(params.LAGS)
    fig, ax = plt.subplots(figsize=(15, 15))
    ax.plot(x_axis_feature, features.cpu(), "-g", label="features_" + station_name)
    ax.plot(x_axis_feature, targets.cpu(), "-b", label="target_" + station_name)
    ax.plot(x_axis_feature, predictions_rec.cpu(), "-r", label="predicted_rec_" + station_name)
    ax.plot(x_axis_feature, predictions.cpu(), "-y", label="predicted_" + station_name)
    ax.set_title(station_name)
    # The curves carry labels; without a legend they were never shown
    ax.legend()
    pil_image = buffer_plot_and_get(fig)
    return pil_image, plt


def visualization_output_imputation_F(x, y_f, h_f, params, station_name):
    """
    Plot one station's input window followed by its target and predicted continuation.

    Inputs:
        x : input sequence of the station (length params.LAGS)
        y_f : target continuation (length params.PREDICTION_WINDOW)
        h_f : predicted continuation (same length as y_f)
        params : needs LAGS and PREDICTION_WINDOW
        station_name : used in the title and legend labels
    Outputs:
        (pil_image, plt) : the rendered figure as a PIL image, and the pyplot module
                           (whose current figure is the one drawn here)
    """
    x_axis_feature = np.arange(params.LAGS)
    # The forecast axis starts at the last input step so the continuations join the input curve
    x_axis_target = np.arange(params.LAGS - 1, params.LAGS + params.PREDICTION_WINDOW)
    last_input = x[-1].unsqueeze(dim=0).cpu()
    target_sequence = torch.cat((last_input, y_f.cpu()))
    predicted_sequence = torch.cat((last_input, h_f.cpu()))
    fig, ax = plt.subplots(figsize=(25, 15))

    ax.plot(x_axis_feature, x.cpu(), linestyle='-', color='b', label="y_i_" + station_name)
    ax.plot(x_axis_target, target_sequence.cpu(), linestyle='--', color="b", label="y_f_" + station_name)
    ax.plot(x_axis_target, predicted_sequence.cpu(), linestyle='-', color="r", label="h_f_" + station_name)
    # Vertical marker at the last observed step (fixed y-range -0.2 .. 1)
    ax.plot([x_axis_feature[-1], x_axis_feature[-1]], [-0.2, 1], linestyle='-.', color='k',
            label="limit_" + station_name)

    ax.set_title(station_name)
    ax.set_xlabel("Hours")
    # The curves carry labels; without a legend they were never shown
    ax.legend()
    pil_image = buffer_plot_and_get(fig)

    return pil_image, plt


def visualization_output_imputation_IF(x_c, y_i, h_i, h_i_rec, y_f, h_f, params, station_name):
    """
    Plot one station for the joint imputation + forecasting task.

    Inputs:
        x_c : corrupted input sequence; corrupted steps hold -1 (check_mask's default
              placeholder) and are marked with black crosses
        y_i : ground-truth input sequence (length params.LAGS)
        h_i : raw imputation prediction; accepted for interface compatibility, not plotted
        h_i_rec : reconstructed input sequence (as built by load_prediction_data_IF)
        y_f : ground-truth forecast (length params.PREDICTION_WINDOW)
        h_f : predicted forecast (same length as y_f)
        params : needs LAGS and PREDICTION_WINDOW
        station_name : used in the title and legend labels
    Outputs:
        (pil_image, plt) : the rendered figure as a PIL image, and the pyplot module
                           (whose current figure is the one drawn here)
    """
    x_axis_feature = np.arange(params.LAGS)
    x_axis_target = np.arange(params.LAGS - 1, params.LAGS + params.PREDICTION_WINDOW)
    # Prefix both forecasts with the last true input value so they join the input curve
    last_input = y_i[-1].unsqueeze(dim=0).cpu()
    target_sequence = torch.cat((last_input, y_f.cpu()))
    predicted_sequence = torch.cat((last_input, h_f.cpu()))
    fig, ax = plt.subplots(figsize=(25, 15))

    # Time steps whose input was corrupted; boolean mask (uint8 indexing is deprecated)
    mask = check_mask(x_c.cpu(), 'cpu').bool()
    x_axis_corrupted = torch.tensor(x_axis_feature)[mask]

    ax.plot(x_axis_feature, y_i.cpu(), linestyle='--', color='b', label="y_i_" + station_name)
    ax.plot(x_axis_target, target_sequence.cpu(), linestyle='--', color="b", label="y_f_" + station_name)
    ax.plot(x_axis_target, predicted_sequence.cpu(), linestyle='-', color="r", label="h_f_" + station_name)
    ax.plot(x_axis_feature, h_i_rec.cpu(), linestyle='-', color="r", label="h_i_rec_" + station_name)
    ax.plot(x_axis_corrupted, np.ones(len(x_axis_corrupted)) * -0.001, 'kP', label="x_c_" + station_name)
    # Vertical marker at the last observed step (fixed y-range -0.2 .. 1)
    ax.plot([x_axis_feature[-1], x_axis_feature[-1]], [-0.2, 1], linestyle='-.', color='k',
            label="limit_" + station_name)

    ax.set_title(station_name)
    ax.set_xlabel("Hours")
    # The curves carry labels; without a legend they were never shown
    ax.legend()
    pil_image = buffer_plot_and_get(fig)
    return pil_image, plt


def load_prediction_data(x, h, y, params, mask, phase):
    """
    Log the imputation of one randomly chosen station to wandb.

    Inputs:
        x : data batch corrupted, one row per station
        h : predictions
        y : targets
        params : needs NUM_STATION (plus what visualization_output_imputation reads)
        mask : 1 where the prediction should be used (corrupted entries), 0 where x is kept
        phase : prefix of the wandb keys ("<phase>_plot", "<phase>_images")
    Outputs:
        Load images on wandb logger
    """
    index = random.randint(0, params.NUM_STATION - 1)
    station_mask = (mask[index, :]).float()
    features = x[index, :]
    predictions = h[index, :]
    targets = y[index, :]
    # Prediction on the masked steps, original input everywhere else
    predictions_reconstructed = torch.mul(station_mask, predictions) + torch.mul(1 - station_mask, features)
    image_plot, plt_plot = visualization_output_imputation(features, targets, predictions, predictions_reconstructed,
                                                           params,
                                                           station_name=str(index))
    wandb.log({phase + "_plot": plt_plot})
    wandb.log({phase + "_images": wandb.Image(image_plot)})
    # Every call opens a new figure; close it once logged so figures don't pile up in memory
    plt_plot.close()


def load_prediction_data_IF(x_c, h_i, y_i, y_f, h_f, params, mask, phase):
    """
    Log the imputation + forecast of one randomly chosen station to wandb.

    Inputs:
        x_c : data batch corrupted, one row per station
        h_i : imputation predictions
        y_i : imputation targets
        y_f : forecasting targets
        h_f : forecasting predictions
        params : needs NUM_STATION (plus what visualization_output_imputation_IF reads)
        mask : 1 where the imputed value should be used, 0 where the target is kept
        phase : prefix of the wandb keys ("<phase>_plot", "<phase>_images")
    Outputs:
        Load images on wandb logger
    """
    index = random.randint(0, params.NUM_STATION - 1)
    station_mask = (mask[index, :]).float()
    features_corrupted = x_c[index, :]
    predictions_i = h_i[index, :]
    targets_i = y_i[index, :]
    predictions_f = h_f[index, :]
    targets_f = y_f[index, :]
    # Imputed values on the masked steps, ground truth everywhere else
    predictions_reconstructed_i = torch.mul(station_mask, predictions_i) + torch.mul(1 - station_mask, targets_i)
    image_plot, plt_plot = visualization_output_imputation_IF(features_corrupted, targets_i, predictions_i,
                                                              predictions_reconstructed_i,
                                                              targets_f,
                                                              predictions_f,
                                                              params,
                                                              station_name=str(index))
    wandb.log({phase + "_plot": plt_plot})
    wandb.log({phase + "_images": wandb.Image(image_plot)})
    # Every call opens a new figure; close it once logged so figures don't pile up in memory
    plt_plot.close()


def load_prediction_data_F(x, y_f, h_f, params, phase):
    """
    Log the forecast of one randomly chosen station to wandb.

    Inputs:
        x : input batch, one row per station
        y_f : forecasting targets
        h_f : forecasting predictions
        params : needs NUM_STATION (plus what visualization_output_imputation_F reads)
        phase : prefix of the wandb keys ("<phase>_plot", "<phase>_images")
    Outputs:
        Load images on wandb logger
    """
    index = random.randint(0, params.NUM_STATION - 1)
    features_f = x[index, :]
    predictions_f = h_f[index, :]
    targets_f = y_f[index, :]
    image_plot, plt_plot = visualization_output_imputation_F(features_f,
                                                             targets_f,
                                                             predictions_f,
                                                             params,
                                                             station_name=str(index))
    wandb.log({phase + "_plot": plt_plot})
    wandb.log({phase + "_images": wandb.Image(image_plot)})
    # Every call opens a new figure; close it once logged so figures don't pile up in memory
    plt_plot.close()


def load_prediction_data_F2(x, y_f, h_f, params, phase):
    """
    Log the paper-style forecast plot of one randomly chosen station to wandb (image only).

    Inputs:
        x : input batch, one row per station
        y_f : forecasting targets
        h_f : forecasting predictions
        params : needs NUM_STATION (plus what visualization_output_imputation_F2 reads)
        phase : prefix of the wandb key ("<phase>_images")
    Outputs:
        Load images on wandb logger
    """
    index = random.randint(0, params.NUM_STATION - 1)
    features_f = x[index, :]
    predictions_f = h_f[index, :]
    targets_f = y_f[index, :]
    image_plot, plt_plot = visualization_output_imputation_F2(features_f,
                                                              targets_f,
                                                              predictions_f,
                                                              params,
                                                              station_name=str(index))
    wandb.log({phase + "_images": wandb.Image(image_plot)})
    # Every call opens a new figure; close it once logged so figures don't pile up in memory
    plt_plot.close()


def visualization_output_imputation_F2(x, y_f, h_f, params, station_name):
    """
    Paper-style forecasting plot: input window, target and predicted continuation.

    Inputs:
        x : input sequence of the station (length params.LAGS)
        y_f : target continuation (length params.PREDICTION_WINDOW)
        h_f : predicted continuation (same length as y_f)
        params : needs LAGS, PREDICTION_WINDOW and SAVE_IMGS
        station_name : title of the logged image
    Outputs:
        (pil_image, plt) : the rendered figure (with title) as a PIL image, and the pyplot module

    Side effects:
        * updates the global matplotlib rcParams (Times New Roman, 4.7x3.2 in figures, ...)
          and sets pdf.fonttype to 42; this persists for every later figure;
        * if params.SAVE_IMGS, saves the figure (without title) as
          imgs/<random id>_prova.pdf, creating imgs/ if it does not exist.
    """
    import os  # local import: only needed for the optional PDF export

    params_img = {
        'axes.labelsize': 10,
        'font.family': 'Times New Roman',
        'font.size': 11,
        'legend.fontsize': 9,
        'xtick.labelsize': 9,
        'ytick.labelsize': 9,
        'text.usetex': False,
        'figure.figsize': [4.7, 3.2],
        'lines.linewidth': 1.2
    }

    rcParams.update(params_img)
    matplotlib.rc('pdf', fonttype=42)

    x_axis_feature = np.arange(params.LAGS)
    # The forecast axis starts at the last input step so the continuations join the input curve
    x_axis_target = np.arange(params.LAGS - 1, params.LAGS + params.PREDICTION_WINDOW)
    target_sequence = torch.cat((x[-1].unsqueeze(dim=0).cpu(), y_f.cpu()))
    predicted_sequence = torch.cat((x[-1].unsqueeze(dim=0).cpu(), h_f.cpu()))

    fig, ax = plt.subplots()
    ax.yaxis.grid(linewidth=0.5, alpha=0.3)

    plt.plot(x_axis_feature, x.cpu(), linestyle='-', color='b', label="S_x")  # Input sequence
    plt.plot(x_axis_target, target_sequence.cpu(), linestyle='--', color="b", label="S_y")  # Target sequence
    plt.plot(x_axis_target, predicted_sequence.cpu(), linestyle='-', color="r", label="S_h")  # Predicted sequence
    # Unlabelled vertical marker at the current time step
    plt.plot([x_axis_feature[-1], x_axis_feature[-1]], [-0.2, 1], linestyle='-.',
             color='k')

    plt.xlabel('Hours')
    plt.ylabel('Normalized power')
    ax.legend()
    plt.tight_layout()
    # Drawn unconditionally so the global `random` stream is consumed as before
    img_id = random.randint(1, 1000000)
    if params.SAVE_IMGS:
        # savefig does not create missing directories
        os.makedirs('imgs', exist_ok=True)
        fig.savefig('imgs/' + str(img_id) + '_prova.pdf', bbox_inches='tight', pad_inches=0)

    # The title is added after the PDF export, so only the logged image carries it
    plt.title(station_name)
    pil_image = buffer_plot_and_get(fig)

    return pil_image, plt


def hard_gradient_penalty(net, real_data, fake_data, device):
    """
    Gradient penalty on random element-wise mixes of real and fake samples.

    Each element of the interpolated input is taken from ``real_data`` or ``fake_data``
    with probability 1/2. This is a hard 0/1 mask, not a convex combination. The penalty
    is the batch mean of (||d net / d input||_2 - 1)^2, with each sample's gradient flattened.

    Inputs:
        net : critic applied to the interpolated batch
        real_data, fake_data : tensors of the same shape, first dimension = batch
        device : device on which the mask and the interpolates are built
    Outputs:
        gp : scalar tensor, differentiable w.r.t. net's parameters (create_graph=True).
             No gradient flows back into fake_data.
    """
    # torch.rand draws U[0, 1) exactly like the legacy FloatTensor(shape).uniform_()
    mask = torch.rand(real_data.shape, dtype=torch.float32, device=device) > 0.5
    inv_mask = ~mask
    mask, inv_mask = mask.float(), inv_mask.float()

    interpolates = (mask * real_data + inv_mask * fake_data).to(device)
    # Fresh leaf tensor: detach() reproduces the deprecated
    # autograd.Variable(..., requires_grad=True), which also detached its input
    interpolates = interpolates.detach().requires_grad_(True)
    c_interpolates = net(interpolates)

    gradients = autograd.grad(
        outputs=c_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(c_interpolates),
        create_graph=True,
        retain_graph=True,
    )[0]

    gradients = gradients.view(gradients.size(0), -1)
    gp = (gradients.norm(2, dim=1) - 1).pow(2).mean()
    return gp


# Cumulative loss display
def loss_used(params):
    """
    Names of the test metrics summarised across runs for the given configuration.

    Forecasting runs report a loss and an MAE, using the multivariate ("FM") or
    univariate ("FU") key names. Any other configuration reports nothing.
    """
    if not params.FORECASTING:
        return []
    if params.MULTIVARIATE:
        return ['test_loss_FM_power', 'test_MAE_FM']
    return ['test_loss_FU', 'test_MAE_FU']


def get_value_of_same_loss(list_of_dict, name):
    """
    Collect the value stored under ``name`` in every dict of ``list_of_dict``.

    Returns a numpy array in the order of ``list_of_dict``. Raises KeyError if a dict lacks ``name``.
    """
    return np.array([metrics[name] for metrics in list_of_dict])


def print_runs_results(params, RUNS, results):
    """
    Print the mean and standard deviation, across runs, of every metric from loss_used(params).

    Inputs:
        params : configuration object passed to loss_used
        RUNS : unused
        results : list of per-run metric dicts
    Each metric is printed twice: once as-is, and once ("EXC_" lines) with ',' as the
    decimal separator (presumably for pasting into a spreadsheet).
    """
    for metric in loss_used(params):
        values = get_value_of_same_loss(results, metric)
        mean, std = values.mean(), values.std()
        print("----", metric, "----")
        print("Mean: ", mean, "Std: ", std)
        mean_txt = str(mean).replace('.', ',')
        std_txt = str(std).replace('.', ',')
        print("EXC_Mean: ", mean_txt, "EXC_Std: ", std_txt, '\n')


# Shuffle dataset
def temporal_signal_split_and_shuffle(data_iterator, shuffle=True, train_ratio: float = 0.8):
    """
    Split a StaticGraphTemporalSignal into train and test signals.

    Inputs:
        data_iterator : StaticGraphTemporalSignal to split
        shuffle : if True, snapshots are assigned to the splits in a random order
                  (np.random.shuffle); if False, the first snapshots form the train split
        train_ratio : fraction of snapshots in the train split (count truncated to int)
    Outputs:
        (train_iterator, test_iterator), both reusing the original edge_index / edge_weight
    Raises:
        TypeError : if data_iterator is not a StaticGraphTemporalSignal
                    (previously this surfaced as an UnboundLocalError)
    """
    if not isinstance(data_iterator, StaticGraphTemporalSignal):
        raise TypeError(
            "temporal_signal_split_and_shuffle only supports StaticGraphTemporalSignal, got "
            + type(data_iterator).__name__
        )

    total_length = data_iterator.snapshot_count
    train_snapshots = int(train_ratio * total_length)

    def _signal(features, targets):
        # New temporal signal on the same static graph
        return StaticGraphTemporalSignal(
            data_iterator.edge_index,
            data_iterator.edge_weight,
            features,
            targets,
        )

    if shuffle:
        order = np.arange(total_length, dtype=int)
        np.random.shuffle(order)
        train_index = order[:train_snapshots].tolist()
        test_index = order[train_snapshots:].tolist()
        all_features = np.array(data_iterator.features)
        all_targets = np.array(data_iterator.targets)
        return (_signal(all_features[train_index], all_targets[train_index]),
                _signal(all_features[test_index], all_targets[test_index]))

    return (_signal(data_iterator.features[:train_snapshots], data_iterator.targets[:train_snapshots]),
            _signal(data_iterator.features[train_snapshots:], data_iterator.targets[train_snapshots:]))


# Define run name
def get_run_name(db, input_ws, output_ws, params):
    """
    Compose the run name "<GNN_MODEL>_<task>_<db>_<input_ws>_<output_ws>".

    task is 'F' for forecasting, 'FM' for multivariate forecasting, and '-' otherwise.
    """
    if params.FORECASTING:
        task = 'FM' if params.MULTIVARIATE else 'F'
    else:
        task = '-'
    parts = (params.GNN_MODEL, task, db, input_ws, output_ws)
    return "_".join(str(part) for part in parts)


# **Models**

In [7]:
import torch
from torch import nn
import pytorch_lightning as pl
from torch_geometric_temporal.nn.recurrent import GConvLSTM


class GNN_Forecasting(pl.LightningModule):
    """
    Two stacked GConvLSTM cells unrolled over the input lags, followed by a linear head.

    forward(x, edge_index, edge_weight):
        x is reshaped to (num_nodes, INPUT_FEATURE_DIMENSION, LAGS) and fed one lag at a
        time. The first cell's hidden state goes through ReLU, the second cell's through
        Tanh. The head is applied to the second cell's final hidden state.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        p = params
        # Layer attribute names are kept so existing checkpoints still load
        self.recurrent = GConvLSTM(p.INPUT_FEATURE_DIMENSION, p.FILTERS, p.FILTER_SIZE)
        self.recurrent2 = GConvLSTM(p.FILTERS, p.FILTERS, p.FILTER_SIZE)
        self.tanh = torch.nn.Tanh()
        self.relu = torch.nn.ReLU()
        self.linear = torch.nn.Linear(p.INPUT_MLP_DIMENSION, p.OUTPUT_MLP_DIMENSION_FORECASTING)

    def forward(self, x, edge_index, edge_weight):
        num_nodes = x.shape[0]
        filters = self.params.FILTERS
        # Zero initial hidden / cell states of both cells, on x's device
        h_0, c_0, h_1, c_1 = (torch.zeros(num_nodes, filters, device=x.device) for _ in range(4))
        x = torch.reshape(x, (-1, self.params.INPUT_FEATURE_DIMENSION, self.params.LAGS))
        for step in range(self.params.LAGS):
            h_0, c_0 = self.recurrent(x[:, :, step], edge_index, edge_weight, H=h_0, C=c_0)
            h_0 = self.relu(h_0)
            h_1, c_1 = self.recurrent2(h_0, edge_index, edge_weight, H=h_1, C=c_1)
            h_1 = self.tanh(h_1)
        return self.linear(h_1)


class LSTM_Forecasting(pl.LightningModule):
    """
    Graph-free LSTM baseline: an LSTM over the lags, then a linear head applied to the
    flattened sequence of hidden states (HIDDEN_DIMENSION_SINGLE * LAGS features).
    """

    def __init__(self, params):
        super(LSTM_Forecasting, self).__init__()
        self.params = params

        self.rnn = nn.LSTM(self.params.INPUT_FEATURE_DIMENSION,
                           self.params.HIDDEN_DIMENSION_SINGLE,
                           num_layers=self.params.NUMBER_LSTM_LAYERS,
                           batch_first=True)

        self.linear = torch.nn.Linear(self.params.HIDDEN_DIMENSION_SINGLE * self.params.LAGS,
                                      self.params.OUTPUT_MLP_DIMENSION_FORECASTING)

    def forward(self, x):
        """
        x : when MULTIVARIATE, batch * INPUT_FEATURE_DIMENSION * LAGS values, laid out
            feature-major per sample; otherwise (batch, LAGS) (assumed — confirm with caller)
        returns (batch, OUTPUT_MLP_DIMENSION_FORECASTING)
        """
        if self.params.MULTIVARIATE:
            # Use the configured feature count (was hard-coded to 3) so it always
            # matches the LSTM input size declared in __init__
            x = torch.reshape(x, (x.size(0), self.params.INPUT_FEATURE_DIMENSION, self.params.LAGS))  # batch, feat, seq
        else:
            x = torch.unsqueeze(x, dim=1)
        x = torch.transpose(x, 1, 2)  # batch, seq, feat
        h, _ = self.rnn(x)  # batch, seq, hid
        h = h.reshape(h.size(0), -1)
        h = self.linear(h)
        return h


class CNN_Forecasting(pl.LightningModule):
    """
    Graph-free 1-D CNN baseline: four convolutions and one max-pool over the lag axis,
    then a linear head.

    The layer sizes are hard-coded for LAGS = 24. The sequence length goes
    24 -conv1-> 20 -pool-> 18 -conv2-> 16 -conv3-> 14 -conv4-> 12, so the head sees
    32 channels * 12 steps = 384 features. Other LAGS values will not fit the head.
    """

    def __init__(self, params):
        super(CNN_Forecasting, self).__init__()
        self.params = params

        def _padding(in_len, out_len, kernel):
            # Padding that maps in_len -> out_len for a stride-1 layer:
            # out = in + 2 * pad - kernel + 1
            return int((out_len - in_len + kernel - 1) / 2)

        # Construction order is unchanged, so parameter init under a fixed seed is unchanged
        # CONV1: 24 -> 20
        self.cnn1 = nn.Conv1d(self.params.INPUT_FEATURE_DIMENSION, 8, kernel_size=5,
                              padding=_padding(24, 20, 5))
        self.relu = nn.ReLU()
        # POOL: 20 -> 18
        self.maxPool1d = nn.MaxPool1d(3, padding=_padding(20, 18, 3), stride=1)
        # CONV2: 18 -> 16
        self.cnn2 = nn.Conv1d(8, 16, kernel_size=5, padding=_padding(18, 16, 5))
        # CONV3: 16 -> 14
        self.cnn3 = nn.Conv1d(16, 24, kernel_size=3, padding=_padding(16, 14, 3))
        # CONV4: 14 -> 12
        self.cnn4 = nn.Conv1d(24, 32, kernel_size=3, padding=_padding(14, 12, 3))
        # LINEAR: 32 channels * 12 steps
        self.mlp = nn.Linear(32 * 12, self.params.OUTPUT_MLP_DIMENSION_FORECASTING)

    def forward(self, x):
        if self.params.MULTIVARIATE:
            x = torch.reshape(x, (-1, self.params.INPUT_FEATURE_DIMENSION, self.params.LAGS))
        else:
            x = torch.unsqueeze(x, dim=1)
        h = self.maxPool1d(self.relu(self.cnn1(x)))
        h = self.relu(self.cnn2(h))
        h = self.relu(self.cnn3(h))
        h = self.relu(self.cnn4(h))
        h = h.reshape(h.size(0), -1)  # (batch, 32 * 12)
        h = self.mlp(h)
        return h




# **Train**

In [8]:
from sklearn.preprocessing import MinMaxScaler
import json
from sklearn.metrics import mean_absolute_error as MAE

import warnings

warnings.filterwarnings("ignore", category=UserWarning)

from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
from torch.utils.data import Dataset


class Dataset_custom(Dataset):
    """
    Sliding-window graph dataset built from the JSON file at ``params.PATH_DATASET``.

    The JSON must provide "edges", "weights" and the per-time-step station blocks
    "block" (power), "block_temp" (temperature) and "block_wind" (wind). Each block is
    stacked into a (time, station) array and min-max scaled column by column (per station).
    Window i then becomes a torch_geometric ``Data`` with
        x : (num_stations, 3 * LAGS), the [power | temperature | wind] values over
            time steps [i, i + LAGS)
        y : (num_stations, 3 * PREDICTION_WINDOW), the same three signals over
            [i + LAGS, i + LAGS + PREDICTION_WINDOW)
    and edge_index / edge_attr built from "edges" / "weights".

    At most ``params.MAX_SNAPSHOTS`` windows are kept. The cap is ``DEFAULT_MAX_SNAPSHOTS``
    (2300) when params does not define it; ``None`` keeps every window.

    NOTE(review): the scalers are fitted on the whole series, before DataModule's random
    train/val/test split, so val/test statistics influence the scaling.
    """

    # Window cap used when params has no MAX_SNAPSHOTS attribute
    DEFAULT_MAX_SNAPSHOTS = 2300

    def __init__(self, params):
        self.params = params
        self._read_json_data()
        self.lags = None
        self.features = None
        self.features_corrupted = None
        self.targets = None
        self.features_temperatures = None
        self.targets_temperatures = None
        self.features_winds = None
        self.targets_winds = None
        self.number_of_station = None
        self.encoded_data = []
        self.read_dataset(self.params.LAGS)

    def _read_json_data(self):
        """Load the whole JSON dataset file into ``self._dataset``."""
        with open(self.params.PATH_DATASET) as f:
            self._dataset = json.load(f)

    def _get_edges(self):
        """Store the transposed "edges" entry in ``self._edges``."""
        self._edges = np.array(self._dataset["edges"]).T

    def _get_edge_weights(self):
        """Store the "weights" entry (transposed) in ``self._edge_weights``."""
        self._edge_weights = np.array(self._dataset["weights"]).T

    def _scaled_block(self, key):
        """Stack ``self._dataset[key]`` into a (time, station) array and min-max scale each column."""
        return MinMaxScaler().fit_transform(np.stack(self._dataset[key]))

    def _get_targets_and_features(self):
        """Build the aligned lists ``self.features`` / ``self.targets`` of scaled windows."""
        power = self._scaled_block("block")
        temperature = self._scaled_block("block_temp")
        wind = self._scaled_block("block_wind")
        self.number_of_station = power.shape[1]

        def window(start, length):
            # (num_stations, 3 * length): the three signals over steps [start, start + length)
            stop = start + length
            return np.concatenate((power[start:stop, :].T,
                                   temperature[start:stop, :].T,
                                   wind[start:stop, :].T), axis=-1)

        lags = self.lags
        horizon = self.params.PREDICTION_WINDOW
        # Every start whose target window still fits (i + lags + horizon <= T).
        # The former ranges stopped one short and dropped the final complete window.
        n_windows = max(power.shape[0] - lags - horizon + 1, 0)
        starts = range(n_windows)
        max_snapshots = getattr(self.params, 'MAX_SNAPSHOTS', self.DEFAULT_MAX_SNAPSHOTS)
        if max_snapshots is not None:
            starts = starts[:max_snapshots]

        self.features = [window(i, lags) for i in starts]
        self.targets = [window(i + lags, horizon) for i in starts]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.encoded_data[idx]

    def read_dataset(self, lags) -> None:
        """(Re)build ``self.encoded_data``, one ``Data`` per window, for the given number of lags."""
        self.lags = lags
        self._get_edges()
        self._get_edge_weights()
        self._get_targets_and_features()
        # Reset so a second call does not leave stale samples in front of the new ones
        self.encoded_data = []
        for features, targets in zip(self.features, self.targets):
            self.encoded_data.append(Data(x=torch.FloatTensor(features),
                                          edge_index=torch.LongTensor(self._edges),
                                          edge_attr=torch.FloatTensor(self._edge_weights),
                                          y=torch.FloatTensor(targets)))


class DataModule(pl.LightningDataModule):
    """
    Builds Dataset_custom once and randomly splits it into train / val / test loaders.

    70% of the samples go to train. The remaining 30% are halved between val and test,
    and test absorbs any rounding remainder. Only the train loader shuffles.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        self.num_station = None
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        dataset = Dataset_custom(self.params)
        self.num_station = dataset.number_of_station

        len_dataset = len(dataset)
        train_ratio = 0.7
        val_test_ratio = 0.5
        train_snapshots = int(train_ratio * len_dataset)
        val_test_snapshots = len_dataset - train_snapshots
        val_snapshots = int(val_test_ratio * val_test_snapshots)
        test_snapshots = len_dataset - train_snapshots - val_snapshots
        # No generator is passed, so the split follows the global torch seed
        tr_db, val_db, te_db = torch.utils.data.random_split(dataset, [train_snapshots, val_snapshots, test_snapshots])
        self.train_loader = DataLoader(tr_db, batch_size=self.params.BATCH_SIZE, shuffle=True)
        self.val_loader = DataLoader(val_db, batch_size=self.params.BATCH_SIZE)
        self.test_loader = DataLoader(te_db, batch_size=self.params.BATCH_SIZE)

    def get_len_loader(self):
        """
        Number of batches in the train, validation and test loaders.

        Bug fix: the validation and test counts used to iterate the train loader.
        len() of a DataLoader returns its batch count without loading any data.
        """
        return len(self.train_loader), len(self.val_loader), len(self.test_loader)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

    def test_dataloader(self):
        return self.test_loader


class TrainingModule(pl.LightningModule):
    """LightningModule that trains ``GNN_Forecasting`` with manual optimization.

    Per-epoch metrics are logged as ``{stage}_loss_FM_power`` / ``{stage}_MAE_FM``
    when ``params.MULTIVARIATE`` is true, and as ``{stage}_loss_FU`` /
    ``{stage}_MAE_FU`` otherwise, where ``stage`` is ``train``, ``val`` or ``test``.
    When ``params.FORECASTING`` is false no model is built and the steps are no-ops.
    """

    def __init__(self, params):
        """Store ``params`` and, for forecasting, build the model and the MSE loss."""
        super().__init__()
        self.params = params

        # Index of the val/test batch whose prediction is passed to load_prediction_data_F2.
        self.index = 0
        self.train_min_length = None
        self.val_min_length = None
        # Current learning rate; decayed by hand in on_train_epoch_start.
        self.learning_rate = self.params.LR
        # Optimizer steps are issued explicitly in training_step.
        self.automatic_optimization = False

        if self.params.FORECASTING:
            self.GNN_Forecasting = GNN_Forecasting(params)
            self.mse_f = nn.MSELoss()

    """ Index """

    def on_train_epoch_start(self):
        """Refresh batch bookkeeping, apply step decay to the LR and log it."""
        self.train_min_length = min(self.params.LIMIT_TRAIN_BATCHES, self.params.LEN_TRAIN)
        self.val_min_length = min(self.params.LIMIT_VAL_BATCHES, self.params.LEN_VAL)
        self.index = 1  # fixed batch index for prediction plots

        # Multiply the LR by DECAY_LR every LR_ITER epochs; LR_ITER == 0 disables decay
        # instead of raising ZeroDivisionError.
        lr_iter = self.params.LR_ITER
        if lr_iter and self.current_epoch > 0 and self.current_epoch % lr_iter == 0:
            print('Updating learning rate at epoch: ', self.current_epoch)
            main_opt = self.optimizers()
            self.learning_rate = self.learning_rate * self.params.DECAY_LR
            for group in main_opt.param_groups:
                group["lr"] = self.learning_rate
        self.log('lr', self.learning_rate, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def forward(self, input, edge_index, edge_weight):
        """Run the GNN forecaster on node features ``input`` over the given graph."""
        return self.GNN_Forecasting(input, edge_index, edge_weight)

    """ Helpers """

    def _unpack_forecasting_batch(self, batch):
        """Return ``(x, y, edge_index, edge_weight)`` extracted from a graph batch.

        In the univariate case only the first ``LAGS`` input columns are kept;
        targets are always truncated to the first ``PREDICTION_WINDOW`` columns.
        """
        x = batch.x if self.params.MULTIVARIATE else batch.x[:, :self.params.LAGS]
        y = batch.y[:, :self.params.PREDICTION_WINDOW]
        return x, y, batch.edge_index, batch.edge_attr

    def _log_forecasting_metrics(self, stage, loss, y, y_predicted):
        """Log ``loss`` and the MAE for ``stage`` under the task-specific metric names."""
        if self.params.MULTIVARIATE:
            loss_name, mae_name = 'loss_FM_power', 'MAE_FM'
        else:
            loss_name, mae_name = 'loss_FU', 'MAE_FU'
        self.log(f'{stage}_{loss_name}', loss, on_step=False, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=1)
        # sklearn's signature is (y_true, y_pred).
        mae = MAE(y.cpu().detach().numpy(), y_predicted.cpu().detach().numpy())
        self.log(f'{stage}_{mae_name}', mae, on_step=False, on_epoch=True, prog_bar=True,
                 logger=True, batch_size=1)

    def _evaluate_forecasting(self, batch, batch_idx, stage, plot_title):
        """Shared val/test step: predict, log metrics and optionally plot one batch."""
        x, y, edge_index, edge_weight = self._unpack_forecasting_batch(batch)
        x_power = x[:, :self.params.LAGS]
        if not self.params.FORECASTING:
            return
        y_predicted = self.forward(x, edge_index, edge_weight)
        loss_forecasting = self.mse_f(y_predicted, y)
        self._log_forecasting_metrics(stage, loss_forecasting, y, y_predicted)
        if (batch_idx == self.index) and self.params.LOGGER:
            load_prediction_data_F2(x_power, y, y_predicted, self.params, plot_title)

    """ Step """

    def training_step(self, train_batch, batch_idx):
        """Forward, MSE loss and one manual optimizer step; then log loss and MAE."""
        x, y, edge_index, edge_weight = self._unpack_forecasting_batch(train_batch)
        if not self.params.FORECASTING:
            return
        main_opt = self.optimizers()
        y_predicted = self.forward(x, edge_index, edge_weight)
        loss_forecasting = self.mse_f(y_predicted, y)
        main_opt.zero_grad()
        self.manual_backward(loss_forecasting)
        main_opt.step()
        self._log_forecasting_metrics('train', loss_forecasting, y, y_predicted)

    def validation_step(self, val_batch, batch_idx):
        """Evaluate one validation batch."""
        self._evaluate_forecasting(val_batch, batch_idx, 'val', "Validation")

    def test_step(self, test_batch, batch_idx):
        """Evaluate one test batch."""
        self._evaluate_forecasting(test_batch, batch_idx, 'test', "Test")

    def configure_optimizers(self):
        """Return an Adam optimizer over the forecaster's parameters at the current LR."""
        main_opt = torch.optim.Adam(list(self.GNN_Forecasting.parameters()), lr=self.learning_rate)
        return main_opt

# **Main**

In [None]:
import time
from pytorch_lightning.callbacks import ModelCheckpoint
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb
import numpy as np
from time import strftime


# System hyperparameters
class HyperParameters:
    """Default configuration for an experiment; attributes are overridden per run."""
    NAME_PROJECT = 'Prova'
    # Single strftime call so date and time come from the same instant.
    NAME_RUN = 'Run_' + strftime("%d/%m/%y_%H:%M:%S")

    # Training parameters
    BATCH_SIZE = 4
    EPOCHS = 100
    LR = 1e-3
    LR_ITER = 10     # decay the learning rate every LR_ITER epochs
    DECAY_LR = 0.9   # multiplicative decay factor
    LAGS = 24
    PREDICTION_WINDOW = 24
    NUM_STATION = 0  # filled in from the data module at run time
    LEN_TRAIN = 0
    LEN_VAL = 0
    LEN_TEST = 0

    # Model Parameters
    GNN_MODEL = "GConvLSTM"
    NODE_FEATURES = 24  # LAGS  # 32
    FILTERS = 64  # 16
    FILTER_SIZE = 2
    DROPOUT = 0.1

    # MLP Parameters
    INPUT_MLP_DIMENSION = FILTERS
    OUTPUT_MLP_DIMENSION_IMPUTATION = PREDICTION_WINDOW
    OUTPUT_MLP_DIMENSION_FORECASTING = PREDICTION_WINDOW
    INPUT_FEATURE_DIMENSION = 1

    # LSTM Parameters
    HIDDEN_DIMENSION_SINGLE = 30
    NUMBER_LSTM_LAYERS = 2

    # Use case
    FORECASTING = False
    MULTIVARIATE = False

    # Training
    SAVE_IMGS = True
    DEBUG = False
    REPRODUCIBLE = False
    SEED = 42
    NUM_GPUS = 0
    # Any positive GPU count selects CUDA (previously only exactly 1 did).
    DEVICE = "cuda" if NUM_GPUS >= 1 else "cpu"
    NUM_WORKERS = 2  # multiprocessing.cpu_count()
    FAST_DEV_RUN = False
    LOGGER = True
    BAR_REFRESH_RATE = 1
    GRADIENT_CLIP = 0
    ACCUMULATE_GRADIENT_BATCHES = 1
    LIMIT_TRAIN_BATCHES = 1.0  # 1.0 # 1500
    LIMIT_VAL_BATCHES = 1.0  # 1.0 # 400
    LIMIT_TEST_BATCHES = 1.0
    AUTO_LR_FIND = False
    CHECK_VAL_EVERY_N_EPOCH = 4

    # Dataset paths (raw strings: Windows backslashes are not escape sequences).
    DATA_MAP1 = {'PV4': r'H:\Il mio Drive\PhD ICT\Data\Generated_time_series_output_with_weigth_light_multivariate'
                        '.json',
                 'PV31': r'H:\Il mio Drive\PhD ICT\Data\Generated_time_series_output_with_weigth_multivariate_T50'
                         '.json',
                 'PV31T': r'H:\Il mio Drive\PhD ICT\Data\Generated_time_series_output_31_with_weigth_multivariate_and_time'
                          '.json',
                 'PV10': r'H:\Il mio Drive\PhD ICT\Data\Real_time_series_output_3Months_with_weigth_multivariate_T150.json'}

    # Same datasets, mounted on drive G:
    DATA_MAP2 = {'PV4': r'G:\Il mio Drive\PhD ICT\Data\Generated_time_series_output_with_weigth_light_multivariate'
                        '.json',
                 'PV31': r'G:\Il mio Drive\PhD ICT\Data\Generated_time_series_output_with_weigth_multivariate_T50'
                         '.json',
                 'PV31T': r'G:\Il mio Drive\PhD ICT\Data\Generated_time_series_output_31_with_weigth_multivariate_and_time'
                          '.json',
                 'PV10': r'G:\Il mio Drive\PhD ICT\Data\Real_time_series_output_3Months_with_weigth_multivariate_T150.json'}

    # Path
    PATH_DATASET = DATA_MAP2['PV4']

    # Other
    NOTES = ""


params = HyperParameters()
results = []

###################  EXPERIMENT SETUP #################################################################################
params.DEBUG = False
params.NAME_PROJECT = 'TEST_NOTTE'
params.EPOCHS = 100
params.GNN_MODEL = "GNN"

# Window
in_ws = 24
params.LAGS = in_ws
out_win = 24
params.PREDICTION_WINDOW = out_win
params.OUTPUT_MLP_DIMENSION_FORECASTING = out_win

# Task
params.FORECASTING = True
params.MULTIVARIATE = False
if params.MULTIVARIATE:
    params.INPUT_FEATURE_DIMENSION = 3

# Dataset
db = 'PV4'
#params.PATH_DATASET = params.DATA_MAP1[db]
params.PATH_DATASET = '/content/drive/My Drive/PhD ICT/Data/Generated_time_series_output_with_weigth_light_multivariate.json'


# Runs
RUNS = 3
params.REPRODUCIBLE = False
params.EXP_NAME = get_run_name(db, in_ws, out_win, params)
params.LOGGER = False
#######################################################################################################################

if params.DEBUG:
    params.EPOCHS = 2
    params.LIMIT_TRAIN_BATCHES = 100
    params.LIMIT_VAL_BATCHES = 100
    params.LIMIT_TEST_BATCHES = 100
    params.LOGGER = False

# Validation metric the checkpoint callback monitors. It must match the name
# TrainingModule logs for this task: 'val_loss_FM_power' (multivariate) or
# 'val_loss_FU' (univariate).
monitor_metric = "val_loss_FM_power" if params.MULTIVARIATE else "val_loss_FU"

for i in range(RUNS):
    if params.REPRODUCIBLE:
        torch.manual_seed(params.SEED)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        random.seed(params.SEED)
        np.random.seed(params.SEED)
    start = time.time()
    data_module = DataModule(params)
    params.NUM_STATION = data_module.num_station
    params.LEN_TRAIN, params.LEN_VAL, params.LEN_TEST = data_module.get_len_loader()
    print('Number of station: ', str(params.NUM_STATION))
    # Single strftime call so the date and time parts are taken at the same instant.
    params.NAME_RUN = params.EXP_NAME + '__' + strftime("%d/%m/%y_%H:%M:%S")
    print("Run name: ", params.NAME_RUN)
    model = TrainingModule(params)
    params_dict = get_hyperparams_dict(params)
    checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", save_top_k=2, monitor=monitor_metric)

    # Logger
    if params.LOGGER:
        wandb.login()
        wandb.init(project=params.NAME_PROJECT, name=params.NAME_RUN, entity="alessio_v", config=params_dict)
        params.LOGGER = WandbLogger(log_model=False)
        # wandb_logger.watch(model, log_freq=100)  # log='gradients',

    # Checkpointing is currently disabled; pass callbacks=[checkpoint_callback] to enable it.
    trainer = pl.Trainer(
        max_epochs=params.EPOCHS,
        fast_dev_run=params.FAST_DEV_RUN,
        logger=params.LOGGER,
        progress_bar_refresh_rate=params.BAR_REFRESH_RATE,
        gpus=params.NUM_GPUS,
        gradient_clip_val=params.GRADIENT_CLIP,
        check_val_every_n_epoch=params.CHECK_VAL_EVERY_N_EPOCH,
        accumulate_grad_batches=params.ACCUMULATE_GRADIENT_BATCHES,
        limit_train_batches=params.LIMIT_TRAIN_BATCHES,
        limit_val_batches=params.LIMIT_VAL_BATCHES,
        limit_test_batches=params.LIMIT_TEST_BATCHES,
        auto_lr_find=params.AUTO_LR_FIND,
    )

    # To evaluate a saved model instead of training:
    # model = TrainingModule.load_from_checkpoint(LOAD_PATH, params=params)
    # trainer.test(model, data_module)
    trainer.fit(model, data_module)
    loss_dict = trainer.test(model, data_module)
    results.append(loss_dict[0])
    if params.LOGGER:
        wandb.finish()
    end = time.time()
    total_time = (end - start) / 60  # minutes
    print('Total time of training: ', str(total_time))

print_runs_results(params, RUNS, results)


  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..

  | Name            | Type            | Params
----------------------------------------------------
0 | GNN_Forecasting | GNN_Forecasting | 102 K 
1 | mse_f           | MSELoss         | 0     
----------------------------------------------------
102 K     Trainable params
0         Non-trainable params
102 K     Total params
0.409     Total estimated model params size (MB)


Number of station:  4
Run name:  GNN_F_PV4_24_24__29/04/22_23:45:52


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]