# ST-GAT Model

## Load libraries

In [None]:
import json
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import glob
import preprocessing as preprocessing
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device}")
print(torch.__version__)
from torch_geometric.data import Data

from torch.utils.tensorboard import SummaryWriter


# Paths used later ("../data/...", "../runs") are relative to src/, so move
# into src/ when the notebook is launched from the repository root.
src_dir = "./src"
if os.path.exists(src_dir):
    os.chdir(src_dir)


## Data preparation

In [None]:
import importlib
importlib.reload(preprocessing)

def prepare_data(config):
    """
    Build the train/validation/test graph snapshots from the counter CSV files.

    Each ``*.csv`` in ``config["counter_files_path"]`` contributes its
    ``config["target_col"]`` column and becomes one graph node. The counters are
    aligned on their ``Date`` column, and ``config["N_GRAPHS"]`` sliding windows
    of ``F_IN + F_OUT`` consecutive rows are cut from the most recent data. For
    every window a ``Data`` object is built:

    - ``x`` holds the first ``F_IN`` (older) rows, one row per counter.
    - ``y`` holds the following ``F_OUT`` (newer) rows, one row per counter.

    With ``USE_HOLIDAY_FEATURES`` one extra ``x`` column is appended. It is 1.0
    when the last input timestamp equals a Slovenian holiday marker, else 0.0.

    :param config configuration dict (reads USE_HOLIDAY_FEATURES, counter_files_path,
        target_col, F_IN, F_OUT, N_GRAPHS, counters_nontemporal_aggregated)
    :return (train_g, val_g, test_g) lists of graphs, split 60/10/30 in
        chronological order (the oldest windows are used for training)
    :raises ValueError if no counter files are found or there are too few rows
        for the requested number of windows
    """
    print("Preparing data...")

    f_in = config["F_IN"]
    f_out = config["F_OUT"]
    n_graphs = config["N_GRAPHS"]
    window = f_in + f_out

    if config["USE_HOLIDAY_FEATURES"]:
        # holidays for each region
        holidays = pd.read_csv("../data/holidays.csv")

        # Only the country is used below (to select Slovenian holidays), so the
        # region column is not needed.
        holidays = holidays.drop(['region'], axis=1)

        # Rows that differed only by region are now identical; keep one of each.
        holidays = holidays.drop_duplicates()

        holidays = holidays.rename(columns={'date': 'Date'})

        holiday_markers = preprocessing.add_hours_to_holidays(holidays)
        holiday_markers["Date"] = pd.to_datetime(holiday_markers['Date'])
        slovenian_holiday_markers = holiday_markers[holiday_markers["country"] == "Slovenia"]

    # One series per counter file, keyed by the file name without extension.
    # NOTE(review): column (node) order follows glob order; it is presumably
    # expected to match the node numbering used by construct_edge_index - confirm.
    counter_series = {}
    for fname in glob.glob(config["counter_files_path"] + "*.csv"):
        counter_data = pd.read_csv(fname)
        counter_data = preprocessing.fill_gaps(counter_data)
        counter_data['Date'] = pd.to_datetime(counter_data['Date'])
        counter_data.index = counter_data['Date']
        # We don't need all past data: keep only the most recent rows that the
        # windows below can use.
        counter_data = counter_data.sort_index(ascending=False).iloc[:window + n_graphs, :]
        counter_id = os.path.splitext(os.path.basename(fname))[0]
        counter_series[counter_id] = counter_data[config['target_col']]

    if not counter_series:
        raise ValueError(f"No counter CSV files found in {config['counter_files_path']!r}")

    # Align the counters on timestamp and put them in chronological (ascending)
    # order, so that within every window the inputs precede the targets in time.
    counters_df = pd.concat(counter_series, axis=1).sort_index()

    # Prepare edge_index matrix
    counters_aggregated = pd.read_csv(config['counters_nontemporal_aggregated'])
    edge_index, _, _ = preprocessing.construct_edge_index(counters_aggregated)

    n_rows = len(counters_df)
    if n_rows < window + n_graphs - 1:
        raise ValueError(
            f"Need at least {window + n_graphs - 1} time steps for {n_graphs} windows "
            f"of length {window}, but only {n_rows} are available")

    # Prepare matrices X [N_GRAPHS, N_NODES, F_IN] and Y [N_GRAPHS, N_NODES, F_OUT]
    graphs = []
    for k in range(n_graphs):
        # Windows are produced oldest first; the last one ends at the newest row.
        end = n_rows - (n_graphs - 1 - k)
        chunk = counters_df.iloc[end - window:end, :]
        inputs = chunk.iloc[:f_in, :]
        targets = chunk.iloc[f_in:, :]

        X = inputs.to_numpy().T   # one row per counter, f_in columns
        Y = targets.to_numpy().T  # one row per counter, f_out columns
        if config["USE_HOLIDAY_FEATURES"]:
            current_date = inputs.index.max()
            is_holiday = (slovenian_holiday_markers["Date"] == current_date).any()
            X = np.hstack((X, np.full((X.shape[0], 1), 1.0 if is_holiday else 0.0)))

        g = Data()
        # One node per row of x (one per counter).
        g.num_nodes = X.shape[0]
        g.edge_index = edge_index
        g.x = torch.FloatTensor(X)
        g.y = torch.FloatTensor(Y)
        graphs.append(g)

    split_train, split_val, _ = (0.6, 0.1, 0.3)  # 60% Train, 10% Validation, 30% Test
    index_train = int(np.floor(n_graphs * split_train))
    index_val = int(index_train + np.floor(n_graphs * split_val))
    train_g = graphs[:index_train]
    val_g = graphs[index_train:index_val]
    test_g = graphs[index_val:]

    print("Size of train data:", len(train_g))
    print("Size of validation data:", len(val_g))
    print("Size of test data:", len(test_g))

    return train_g, val_g, test_g

## Models

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv, GCNConv
class _SpatioTemporalBase(torch.nn.Module):
    """
    Shared recurrent head and forward pass for ST_GAT and ST_GCN.

    Subclasses set ``n_nodes``, ``n_pred`` and ``dropout``, create their spatial
    graph layer, call ``_build_temporal_head`` and implement ``_spatial``.
    """

    def _build_temporal_head(self, use_lstm, lstm_layer_sizes, gru_layer_sizes):
        """
        Create the stacked single-layer LSTMs/GRUs and the output linear layer.

        Any argument left as None falls back to the notebook-level ``config``
        (USE_LSTM / LSTM_LAYER_SIZES / GRU_LAYER_SIZES).
        :raises ValueError if no recurrent layer size is given
        """
        if use_lstm is None:
            use_lstm = config["USE_LSTM"]
        self.use_lstm = bool(use_lstm)

        if self.use_lstm:
            layer_sizes = config["LSTM_LAYER_SIZES"] if lstm_layer_sizes is None else lstm_layer_sizes
        else:
            layer_sizes = config["GRU_LAYER_SIZES"] if gru_layer_sizes is None else gru_layer_sizes
        layer_sizes = list(layer_sizes)
        if not layer_sizes:
            raise ValueError("At least one recurrent layer size is required")

        # A ModuleList (not a plain list) registers the layers as submodules, so
        # they appear in parameters() (and get optimized), are moved by
        # .to(device), and are saved in state_dict().
        layers = torch.nn.ModuleList()
        input_size = self.n_nodes  # each time step feeds one value per node
        for hidden_size in layer_sizes:
            if self.use_lstm:
                rnn = torch.nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=1)
                for name, param in rnn.named_parameters():
                    if 'bias' in name:
                        torch.nn.init.constant_(param, 0.0)
                    elif 'weight' in name:
                        torch.nn.init.xavier_uniform_(param)
            else:
                rnn = torch.nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=1)
            layers.append(rnn)
            input_size = hidden_size

        if self.use_lstm:
            self.lstms = layers
        else:
            self.grus = layers

        # fully-connected layer: last hidden state -> n_pred values for every node
        self.linear = torch.nn.Linear(layer_sizes[-1], self.n_nodes * self.n_pred)
        torch.nn.init.xavier_uniform_(self.linear.weight)

    def _spatial(self, x, edge_index):
        """Apply the subclass' graph layer to node features ``x``."""
        raise NotImplementedError

    def forward(self, data, device):
        """
        Forward pass: graph layer, dropout, recurrent layers, linear output.

        :param data (Batched) graph data with x, edge_index, num_graphs and num_nodes
        :param device Device the node features are moved to
        :return tensor of shape (num_graphs * n_nodes, n_pred)
        """
        x = data.x.to(device=device, dtype=torch.float)

        x = self._spatial(x, data.edge_index)
        x = F.dropout(x, self.dropout, training=self.training)

        batch_size = data.num_graphs
        n_node = data.num_nodes // batch_size
        # (batch*nodes, features) -> (batch, nodes, features) -> (features, batch, nodes):
        # the feature columns are fed to the RNN as the sequence dimension.
        x = torch.reshape(x, (batch_size, n_node, x.shape[-1]))
        x = torch.movedim(x, 2, 0)

        for rnn in (self.lstms if self.use_lstm else self.grus):
            x, _ = rnn(x)

        # Keep the last step only. Indexing (rather than squeeze) keeps the batch
        # dimension even when batch_size == 1.
        x = self.linear(x[-1])
        return torch.reshape(x, (-1, self.n_pred))


class ST_GAT(_SpatioTemporalBase):
    """
    Spatio-Temporal Graph Attention Network as presented in https://ieeexplore.ieee.org/document/8903252
    """
    def __init__(self, in_channels, out_channels, n_nodes, heads=8, dropout=0.0,
                 use_lstm=None, lstm_layer_sizes=None, gru_layer_sizes=None):
        """
        Initialize the ST-GAT model
        :param in_channels Number of input channels
        :param out_channels Number of output channels (predicted steps per node)
        :param n_nodes Number of nodes in the graph
        :param heads Number of attention heads to use in graph
        :param dropout Dropout probability on output of Graph Attention Network
        :param use_lstm Use LSTM (True) or GRU (False) layers; None reads config["USE_LSTM"]
        :param lstm_layer_sizes Hidden sizes of the LSTM layers; None reads config
        :param gru_layer_sizes Hidden sizes of the GRU layers; None reads config
        """
        super().__init__()
        self.n_pred = out_channels
        self.heads = heads
        self.dropout = dropout
        self.n_nodes = n_nodes

        # Keeps the feature width (out_channels=in_channels); heads are averaged (concat=False).
        self.gat = GATConv(in_channels=in_channels, out_channels=in_channels,
                           heads=heads, dropout=0, concat=False)
        self._build_temporal_head(use_lstm, lstm_layer_sizes, gru_layer_sizes)

    def _spatial(self, x, edge_index):
        return self.gat(x, edge_index)


class ST_GCN(_SpatioTemporalBase):
    """
    Spatio-temporal model using a graph convolution (GCNConv) as the spatial layer.
    """
    def __init__(self, in_channels, out_channels, n_nodes, dropout=0.0,
                 use_lstm=None, lstm_layer_sizes=None, gru_layer_sizes=None):
        """
        Initialize the ST-GCN model
        :param in_channels Number of input channels
        :param out_channels Number of output channels (predicted steps per node)
        :param n_nodes Number of nodes in the graph
        :param dropout Dropout probability on output of the graph convolution
        :param use_lstm Use LSTM (True) or GRU (False) layers; None reads config["USE_LSTM"]
        :param lstm_layer_sizes Hidden sizes of the LSTM layers; None reads config
        :param gru_layer_sizes Hidden sizes of the GRU layers; None reads config
        """
        super().__init__()
        self.n_pred = out_channels
        self.dropout = dropout
        self.n_nodes = n_nodes

        # GCNConv has no attention-dropout or head-concat options (those are
        # GATConv arguments), so only the channel sizes are passed.
        self.gcn = GCNConv(in_channels=in_channels, out_channels=in_channels)
        self._build_temporal_head(use_lstm, lstm_layer_sizes, gru_layer_sizes)

    def _spatial(self, x, edge_index):
        return self.gcn(x, edge_index)


## Train the model

In [None]:
import torch
import torch.optim as optim
from tqdm import tqdm
import time
import os
import matplotlib.pyplot as plt

from torch.utils.tensorboard import SummaryWriter


def model_train(train_dataloader, val_dataloader, config, device, save_test_results = False, test_dataloader = None):
    """
    Train the ST-GAT (or ST-GCN) model, evaluating on train/validation data every 5 epochs.

    The final model/optimizer state, last loss and config are written to a new
    run_<timestamp> directory under config["CHECKPOINT_DIR"].
    :param train_dataloader Data loader of training dataset
    :param val_dataloader Data loader of validation dataset
    :param config configuration to use
    :param device Device to train/evaluate on
    :param save_test_results If True, evaluate on test_dataloader and write results.json
    :param test_dataloader Data loader of test dataset (required if save_test_results)
    :return the trained model
    :raises ValueError if EPOCHS < 1, or test results are requested without a test loader
    """
    if save_test_results and test_dataloader is None:
        raise ValueError("test_dataloader is required when save_test_results is True")
    if config['EPOCHS'] < 1:
        raise ValueError("config['EPOCHS'] must be at least 1")

    # Each graph snapshot is N x F: N nodes (counters), F = F_IN input steps,
    # plus one holiday-flag column when USE_HOLIDAY_FEATURES is set.
    in_channels = config['F_IN'] + (1 if config["USE_HOLIDAY_FEATURES"] else 0)
    model_cls = ST_GAT if config["USE_GAT"] else ST_GCN
    model = model_cls(in_channels=in_channels, out_channels=config['F_OUT'],
                      n_nodes=config['N_NODE'], dropout=config['DROPOUT'])

    optimizer = optim.Adam(model.parameters(), lr=config['INITIAL_LR'], weight_decay=config['WEIGHT_DECAY'])
    # train() instantiates the loss itself, so the class is passed.
    loss_fn = torch.nn.MSELoss
    # One scheduler for the whole run, stepped once per epoch: LR x0.1 every 30 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    model.to(device)

    # For every epoch, train the model on training dataset. Evaluate model on validation dataset
    for epoch in range(config['EPOCHS']):
        loss = train(model, device, train_dataloader, optimizer, loss_fn, epoch)
        scheduler.step()
        print(f"Loss: {loss:.3f}")
        if epoch % 5 == 0:
            # eval() returns (rmse, mae, mape, y_pred, y_truth) in that order.
            train_rmse, train_mae, train_mape, _, _ = eval(model, device, train_dataloader, 'Train')
            val_rmse, val_mae, val_mape, _, _ = eval(model, device, val_dataloader, 'Valid')
            if config["use_tensorboard"]:
                writer.add_scalar("MAE/train", train_mae, epoch)
                writer.add_scalar("RMSE/train", train_rmse, epoch)
                writer.add_scalar("MAPE/train", train_mape, epoch)
                writer.add_scalar("MAE/val", val_mae, epoch)
                writer.add_scalar("RMSE/val", val_rmse, epoch)
                writer.add_scalar("MAPE/val", val_mape, epoch)

    if config["use_tensorboard"]:
        writer.flush()

    # Save the model
    timestr = time.strftime("%m-%d-%H%M%S")
    run_dir = os.path.join(config["CHECKPOINT_DIR"], f"run_{timestr}")
    os.makedirs(run_dir)
    torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": float(loss),
            }, os.path.join(run_dir, "model.pt"))

    with open(os.path.join(run_dir, "config.json"), "w") as fp:
        json.dump(config, fp)

    if save_test_results:
        test_rmse, test_mae, test_mape, _, _ = eval(model, device, test_dataloader, 'Test')
        results = {'MAE': float(test_mae),
                   'RMSE': float(test_rmse),
                   'MAPE': float(test_mape)}
        with open(os.path.join(run_dir, "results.json"), "w") as fp:
            json.dump(results, fp)
    return model

In [None]:
def z_score(x, mean, std):
    """Standardise ``x`` using the given mean and standard deviation."""
    centred = x - mean
    return centred / std


def un_z_score(x_normed, mean, std):
    """Invert z_score: map standardised values back to the original scale."""
    scaled = x_normed * std
    return scaled + mean


def MAPE(v, v_):
    """
    Mean absolute percentage error of prediction ``v_`` against ground truth ``v``.

    The 1e-15 in the denominator only prevents an exact division by zero;
    entries where ``v`` is 0 still contribute enormous percentages, and the
    denominator is not wrapped in abs().
    """
    relative_error = torch.abs(v_ - v) / (v + 1e-15)
    return torch.mean(relative_error * 100)


def RMSE(v, v_):
    """Root mean squared error between ``v`` and ``v_``."""
    squared_error = (v_ - v) ** 2
    return torch.sqrt(squared_error.mean())


def MAE(v, v_):
    """Mean absolute error between ``v`` and ``v_``."""
    return (v_ - v).abs().mean()

In [None]:
@torch.no_grad()
def eval(model, device, dataloader, type=''):
    """
    Evaluate ``model`` on every batch of ``dataloader`` without tracking gradients.

    Metrics are computed per batch and averaged over the evaluated batches, so
    each batch weighs the same regardless of its size.
    :param model Model to evaluate; called as model(batch, device)
    :param device Device to evaluate on
    :param dataloader Data loader (must support len())
    :param type Name of evaluation type, e.g. Train/Val/Test (used in the printout)
    :return tuple (rmse, mae, mape, y_pred, y_truth) -- note RMSE comes first.
        y_pred / y_truth are CPU tensors of shape
        (len(dataloader), rows of the first evaluated batch, pred.shape[1]);
        entries past a smaller batch's size stay zero.
    :raises ValueError if no batch was evaluated
    """
    model.eval()
    model.to(device)

    rmse = 0
    mae = 0
    mape = 0
    n = 0
    y_pred = None
    y_truth = None

    # Evaluate model on all data
    for i, batch in enumerate(dataloader):
        batch = batch.to(device)
        # Batches whose feature matrix has a single row are skipped.
        if batch.x.shape[0] == 1:
            continue
        pred = model(batch, device)
        truth = batch.y.view(pred.shape)
        if y_pred is None:
            # Sized from the first evaluated batch; NOTE(review): assumes later
            # batches are not larger (true for an unshuffled DataLoader).
            y_pred = torch.zeros(len(dataloader), pred.shape[0], pred.shape[1])
            y_truth = torch.zeros_like(y_pred)
        y_pred[i, :pred.shape[0], :] = pred.cpu()
        y_truth[i, :pred.shape[0], :] = truth.cpu()
        rmse += RMSE(truth, pred)
        mae += MAE(truth, pred)
        mape += MAPE(truth, pred)
        n += 1

    if n == 0:
        raise ValueError(f"{type}: no batches were evaluated")
    # average score for each metric over the evaluated batches
    rmse, mae, mape = rmse / n, mae / n, mape / n

    print(f'{type}, MAE: {mae}, RMSE: {rmse}, MAPE: {mape}')

    return rmse, mae, mape, y_pred, y_truth

from torch.optim.lr_scheduler import StepLR 

def train(model, device, dataloader, optimizer, loss_fn, epoch):
    """
    Train ``model`` for one epoch over ``dataloader``.

    :param model Model to train; called as model(batch, device)
    :param device Device to train on
    :param dataloader Data loader
    :param optimizer Optimizer to use
    :param loss_fn Loss class (e.g. torch.nn.MSELoss, instantiated here) or an
        already-constructed loss callable
    :param epoch Current epoch (progress-bar label and TensorBoard step)
    :return detached loss tensor of the last batch
    :raises ValueError if dataloader yields no batches
    """
    # Learning-rate scheduling is left to the caller. A scheduler created here
    # would be rebuilt on every call, restarting its step count each epoch.
    criterion = loss_fn() if isinstance(loss_fn, type) else loss_fn

    model.train()
    loss = None
    for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
        batch = batch.to(device)
        optimizer.zero_grad()
        y_pred = torch.squeeze(model(batch, device))
        loss = criterion(y_pred.float(), torch.squeeze(batch.y).float())
        if config["use_tensorboard"]:
            writer.add_scalar("Loss/train", loss, epoch)
        loss.backward()
        optimizer.step()

    if loss is None:
        raise ValueError("dataloader yielded no batches")
    # Detach so the caller does not keep the last batch's autograd graph alive.
    return loss.detach()

In [None]:
from torch_geometric.loader import DataLoader

# Constant config to use throughout
config = {
    'BATCH_SIZE'                        : 50,
    'EPOCHS'                            : 60,
    'WEIGHT_DECAY'                      : 5e-5,
    'INITIAL_LR'                        : 1e-1,
    'CHECKPOINT_DIR'                    : '../runs',
    'DROPOUT'                           : 0.2,
    "counter_files_path"                : "../data/counters_temporal_data_2023-03-03T09-24-06/",
    "counters_nontemporal_aggregated"   : "../data/counters_non_temporal_aggregated_data.csv",
    "USE_HOLIDAY_FEATURES"              : True,
    "N_GRAPHS"                          : 30*24,
    "F_IN"                              : 7*24,
    "F_OUT"                             : 7*24,
    "N_NODE"                            : 165,
    "target_col"                        : "Sum",
    "use_tensorboard"                   : False,
    "USE_GAT"                           : True,  # if True use GAT, else use GCN
    "USE_LSTM"                          : False,  # if True use LSTM, else use GRU
    "LSTM_LAYER_SIZES"                  : [128, 32],
    "GRU_LAYER_SIZES"                   : [128, 32],
}


# Make a tensorboard writer (read as a global by model_train/train)
if config["use_tensorboard"]:
    writer = SummaryWriter()

# Create the checkpoint directory (and any missing parents) if needed.
os.makedirs(config["CHECKPOINT_DIR"], exist_ok=True)

train_g, val_g, test_g = prepare_data(config)
train_dataloader = DataLoader(train_g, batch_size=config['BATCH_SIZE'], shuffle=False)
val_dataloader = DataLoader(val_g, batch_size=config['BATCH_SIZE'], shuffle=False)
test_dataloader = DataLoader(test_g, batch_size=config['BATCH_SIZE'], shuffle=False)

# Get gpu if you can
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device}")

# Configure and train model
model = model_train(train_dataloader, val_dataloader, config, device, True, test_dataloader)