In [1]:
import json
import glob
import re
import math
import random

import re
import sys
import os
from pathlib import Path
import json

import numpy as np
from numpy import percentile
import pandas as pd

import time
from datetime import datetime
from pytz import utc, timezone

from scipy.stats import iqr
from scipy.stats import rankdata, iqr, trim_mean

from sklearn.metrics import mean_squared_error, precision_score, recall_score, roc_auc_score, f1_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter, Linear, Sequential, BatchNorm1d, ReLU
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.nn import GCNConv, GATConv, EdgeConv

from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader, random_split, Subset
from test import *

## **Contents**

0.   Installations
1.   [Why Graph Neural Networks](#whygnn)?
2.   [Data Pre-Processing](#data_preprocessing)
3.   [Utilities](#util)<br>
        
4.   [TimeDataset (Preparing to Forecast)](#TimeDataset)
5.   [Graph Attention-Based Forecasting](#graph_layer)
6.   [Graph Structure Learning + GDN](#gdn)
7.   [Main](#driver)





### Why Graph Neural Networks?

<a id = "whygnn"> </a>


1.   Given high-dimensional time series data (e.g., sensor data), how can we detect anomalous events ?, events such as system faults
and attacks? More challenging, how can we do this in a way that captures complex inter-sensor relationships, and detects and explains anomalies which deviate from these relationships?
2.   Capturing only the linear relationships is insufficient for complex, highly nonlinear relationships in many real-world settings. Data from these sensors can be related in complex, nonlinear ways. To learn representations for nonlinear high-dimensional time series and predict time series data, deep learning based time series methods have attracted interest in recent years.

3.  In recent years, graph neural networks (GNNs) have emerged as successful approaches for modelling complex patterns in graph-structured data. In
general, GNNs assume that the state of a node is influenced by the states of its neighbors. 

4. GNNs use the same model parameters to model the behavior of each node, and hence face limitations in representing very different behaviors of different sensors. Moreover, GNNs typically require the graph structure as an input, whereas the graph structure is initially unknown in many cases, and needs to be learned from data.

5. GNNs are highly scalable as they use one model, and the same set of parameters to get the embeddings for all nodes. 

### Data Pre-Processing(Example using WADI data set, depending on the problem, it can be modified.)

<a id = "data_preprocessing"> </a>


1. We can modify the paths to training and testing data. 
2. Preprocessing involves missing data imputation with the mean values, normalizing the train/test data with MinMax Scaler. 
3. Downsample the data by 10 units(pick 1 timestep that represent 10 timesteps.)

In [2]:

def norm(train, test):
    normalizer = MinMaxScaler(feature_range=(0, 1)).fit(train) # scale training data to [0,1] range
    train_ret = normalizer.transform(train)
    test_ret = normalizer.transform(test)

    return train_ret, test_ret


def downsample(data, labels, down_len):
    np_data = np.array(data)
    np_labels = np.array(labels)

    orig_len, col_num = np_data.shape
    down_time_len = orig_len // down_len
    np_data = np_data.transpose()
    d_data = np_data[:, :down_time_len*down_len].reshape(col_num, -1, down_len)
    d_data = np.median(d_data, axis=2).reshape(col_num, -1)

    d_labels = np_labels[:down_time_len*down_len].reshape(-1, down_len)
    # if exist anomalies, then this sample is abnormal
    d_labels = np.round(np.max(d_labels, axis=1))

    d_data = d_data.transpose()
    return d_data.tolist(), d_labels.tolist()


def pre_process(train_path, test_path, list_save_path, train_save_path, test_save_path):

    train = pd.read_csv(train_path, index_col=0)
    test = pd.read_csv(test_path, index_col=0)
    train = train.iloc[:, 3:]
    test = test.iloc[:, 3:]
    train = train.fillna(train.mean())
    test = test.fillna(test.mean())
    train = train.fillna(0)
    test = test.fillna(0)

    # trim column names
    train = train.rename(columns=lambda x: x.strip())
    test = test.rename(columns=lambda x: x.strip())

    train_labels = np.zeros(len(train))
    test_labels = test.attack


    test = test.drop(columns=['attack'])
    cols = train.columns
    train.columns = cols
    test.columns = cols

    x_train, x_test = norm(train.values, test.values)

    d_train_x, d_train_labels = downsample(x_train, train_labels, 10)
    d_test_x, d_test_labels = downsample(x_test, test_labels, 10)

    train_df = pd.DataFrame(d_train_x, columns = train.columns)
    test_df = pd.DataFrame(d_test_x, columns = test.columns)


    test_df['attack'] = d_test_labels
    train_df['attack'] = d_train_labels

    train_df = train_df.iloc[2160:]

    train_df.to_csv(train_save_path)
    test_df.to_csv(test_save_path)

    f = open(list_save_path, 'w')
    for col in train.columns:
        f.write(col+'\n')
    f.close()

train_path = "C:/Users/harsh/OneDrive/Desktop/notebooks/WADI_14days.csv"
test_path = 'C:/Users/harsh/OneDrive/Desktop/notebooks/WADI_attackdata_labelled.csv'
list_save_path = 'C:/Users/harsh/OneDrive/Desktop/notebooks/list.txt'

train_save_path = 'C:/Users/harsh/OneDrive/Desktop/notebooks/train.csv'
test_save_path = 'C:/Users/harsh/OneDrive/Desktop/notebooks/test.csv'


pre_process(train_path, test_path, list_save_path, train_save_path, test_save_path)

### Utilities
<a id = 'util'></a>

Functions include helpers (assigning data to a device, **ex**: cuda or cpu), calculating scores(F1 etc) and statistics (median, IQR etc)

In [3]:
def get_feature_map(feature_list_path):
    feature_file = open(feature_list_path, 'r')
    feature_list = []
    for ft in feature_file:
        feature_list.append(ft.strip())

    return feature_list

# graph is 'fully-connect'
def get_fc_graph_struc(feature_list_path):
    feature_file = open(feature_list_path, 'r')

    struc_map = {}
    feature_list = []
    for ft in feature_file:
        feature_list.append(ft.strip())

    for ft in feature_list:
        if ft not in struc_map:
            struc_map[ft] = []

        for other_ft in feature_list:
            if other_ft is not ft:
                struc_map[ft].append(other_ft)
    
    return struc_map


def construct_data(data, feature_map, labels=0):
    res = []

    for feature in feature_map:
        if feature in data.columns:
            res.append(data.loc[:, feature].values.tolist())
        else:
            print(feature, 'not exist in data')
    # append labels as last
    sample_n = len(res[0])

    if type(labels) == int:
        res.append([labels]*sample_n)
    elif len(labels) == sample_n:
        res.append(labels)

    return res

def build_loc_net(struc, all_features, feature_map=[]):

    index_feature_map = feature_map
    edge_indexes = [[],[]]
    for node_name, node_list in struc.items():
        if node_name not in all_features:
            continue

        if node_name not in index_feature_map:
            index_feature_map.append(node_name)
        
        p_index = index_feature_map.index(node_name)
        for child in node_list:
            if child not in all_features:
                continue

            if child not in index_feature_map:
                print(f'error: {child} not in index_feature_map')
                #index_feature_map.append(child)

            c_index = index_feature_map.index(child)
            edge_indexes[0].append(c_index)
            edge_indexes[1].append(p_index)

    return edge_indexes

def get_batch_edge_index(org_edge_index, batch_num, node_num):
    # org_edge_index:(2, edge_num)
    edge_index = org_edge_index.clone().detach()
    edge_num = org_edge_index.shape[1]
    batch_edge_index = edge_index.repeat(1,batch_num).contiguous()

    for i in range(batch_num):
        batch_edge_index[:, i*edge_num:(i+1)*edge_num] += i*node_num

    return batch_edge_index.long()


### **TimeDataset (Preparing to Forecast)**

<a id='TimeDataset'></a>

1.   Thus, at time t, define the model input $x(t) \in \mathbb{R}^{N \times W}$
based on a sliding window of size w over the historical time
series data (whether training or testing data).


>    <center>$x_{t} = [s^{(t - w)}, s^{(t - w + 1)} .... s^{(t - 1)}]$</center>

<center> The target output that the model needs to predict is the sensor data at the current time tick, i.e. $s^{(t)}$. </center>

In [4]:
class TimeDataset(Dataset):
    def __init__(self, raw_data, edge_index, mode='train', config = None):
        self.raw_data = raw_data
        self.config = config
        self.edge_index = edge_index
        self.mode = mode

        x_data = raw_data[:-1]
        labels = raw_data[-1]
        data = x_data
        # to tensor
        data = torch.tensor(data).double()
        labels = torch.tensor(labels).double()

        self.x, self.y, self.labels = self.process(data, labels)
    
    def __len__(self):
        return len(self.x)


    def process(self, data, labels):
        x_arr, y_arr = [], []
        labels_arr = []

        slide_win, slide_stride = [self.config[k] for k in ['slide_win', 'slide_stride']]
        is_train = self.mode == 'train'
        node_num, total_time_len = data.shape
        rang = range(slide_win, total_time_len, slide_stride) if is_train else range(slide_win, total_time_len)
        
        for i in rang:
            ft = data[:, i-slide_win:i]
            tar = data[:, i]
            x_arr.append(ft)
            y_arr.append(tar)
            labels_arr.append(labels[i])

        x = torch.stack(x_arr).contiguous()
        y = torch.stack(y_arr).contiguous()

        labels = torch.Tensor(labels_arr).contiguous()
        
        return x, y, labels

    def __getitem__(self, idx):
        feature = self.x[idx].double()
        y = self.y[idx].double()
        edge_index = self.edge_index.long()
        label = self.labels[idx].double()

        return feature, y, label, edge_index


### **Graph Attention-Based Forecasting**


<a id='graph_layer'></a>


1.   To capture the relationships between
sensors, a graph attention-based feature extractor is introduced to fuse a node’s information with its neighbors based on
the learned graph structure. Feature extractor incorporates the sensor
embedding vectors $v_i$
, which characterize the different behaviors of different types of sensors. To do this, compute
node i’s aggregated representation $z_i$ as follows:

<center>$z^{(t)}_{i} = ReLU(\alpha_{i, i}\textbf{W}x^{(t)}_i + \sum_{j \in N(i)} \alpha_{i, j}\textbf{W}x^{(t)}_j)$</center>


<center>where $x^{(t)}_i \in \mathbb{R}^{w}$ is node i's input feature
$N(i) = {j | A_{ji} > 0}$ is the set of neighbors of node i obtained from
the learned adjacency matrix A, $W \in \mathbb{R}^{d \times w}$ is a trainable
weight matrix which applies a shared linear transformation to every node, and the attention coefficients $\alpha_{i, j}$ are computed as:</center>

<center>$g^{(t)}_{i} = v_{i} \oplus \textbf{W}x^{(t)}_{i}$</center>

<center>$\pi(i, j) = LeakyReLU(a^{T} (g^{(t)}_{i} \oplus g^{(t)}_{j}))$</center>
<center>$\alpha(i, j) = SoftMax(\pi(i, j))$</center>




### Graph Structure Learning + GDN
<a id = "gdn"></a>

1.   A major goal of this framework is to learn the relationships
between sensors in the form of a graph structure. To do this,
a directed graph is used, whose nodes represent sensors, and whose edges represent dependency relationships
between them.

2. An edge from one sensor to another indicates
that the first sensor is used for modelling the behavior of the
second sensor. **A directed graph is used because the dependency patterns between sensors need not be symmetric.**

3. A flexible framework is applied either to:<br>
    3.1 The usual case where we have no prior information about the graph structure.<br>
    
    3.2 The case where we have some prior information about which edges are plausible (e.g. the sensor system may be divided into parts, where sensors in different parts have minimal interaction).<br>


4. This prior information can be flexibly represented as a set
of candidate relations $C_i$ for each sensor i, i.e. the sensors
it could be dependent on:

<center>{$C_i \subset \{1,3, 8, ... N\}$ \ ${i}$, no self loop.}</center>


5. In the case without prior information, the candidate relations
of sensor i is simply all sensors, other than itself.

6.  The output of our algorithm is a set of $T_{test}$  binary labels
indicating whether each test time tick is  = 1 an anomaly or not,
i.e. $a(t) \in \{0, 1\}$, where $a(t)$ indicates that time(t) is
anomalous.

7.  To select the dependencies of sensor i among these candidates, compute the similarity between node i’s embedding vector, and the embeddings of its candidates ${j \in C_{i}}$

That is, first compute $e_{ji}$, the normalized dot product between the embedding vectors of sensor i, and the candidate
relation $j \in C_{i}$
. Then select the top k such normalized
dot products: here **TopK** denotes the indices of top-k values among its input (i.e. the normalized dot products). **The value of k can be chosen by the user according to the desired sparsity level**. Next, a graph attention-based
model is defined which makes use of this learned adjacency matrix A.


In [5]:
class GraphLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, heads=1, concat=True, negative_slope=0.2, dropout=0, bias=True, inter_dim=-1,**kwargs):
        super(GraphLayer, self).__init__(aggr='add', **kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.node_dim = 0
        self.__alpha__ = None
        self.lin = Linear(in_channels, heads * out_channels, bias=False)

        self.att_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_j = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_i = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_em_j = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.lin.weight)
        glorot(self.att_i)
        glorot(self.att_j)
        zeros(self.att_em_i)
        zeros(self.att_em_j)
        zeros(self.bias)

    def forward(self, x, edge_index, embedding, return_attention_weights=False):
        if torch.is_tensor(x):
            x = self.lin(x)
            x = (x, x)
        else:
            x = (self.lin(x[0]), self.lin(x[1]))

        edge_index, _ = remove_self_loops(edge_index)
        edge_index, _ = add_self_loops(edge_index, num_nodes=x[1].size(self.node_dim))
        out = self.propagate(edge_index, x=x, embedding=embedding, edges=edge_index, return_attention_weights=return_attention_weights)

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        if return_attention_weights:
            alpha, self.__alpha__ = self.__alpha__, None
            return out, (edge_index, alpha)
        else:
            return out

    def message(self, x_i, x_j, edge_index_i, size_i, embedding, edges, return_attention_weights):
        x_i = x_i.view(-1, self.heads, self.out_channels)
        x_j = x_j.view(-1, self.heads, self.out_channels)

        if embedding is not None:
            embedding_i, embedding_j = embedding[edge_index_i], embedding[edges[0]]
            embedding_i = embedding_i.unsqueeze(1).repeat(1,self.heads,1)
            embedding_j = embedding_j.unsqueeze(1).repeat(1,self.heads,1)

            key_i = torch.cat((x_i, embedding_i), dim=-1)
            key_j = torch.cat((x_j, embedding_j), dim=-1)

        cat_att_i = torch.cat((self.att_i, self.att_em_i), dim=-1)
        cat_att_j = torch.cat((self.att_j, self.att_em_j), dim=-1)
        alpha = (key_i * cat_att_i).sum(-1) + (key_j * cat_att_j).sum(-1)
        alpha = alpha.view(-1, self.heads, 1)
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, edge_index_i, num_nodes = size_i)

        if return_attention_weights:
            self.__alpha__ = alpha

        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        
        return x_j * alpha.view(-1, self.heads, 1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__, self.in_channels, self.out_channels, self.heads)

    

class GNNLayer(nn.Module):
    def __init__(self, in_channel, out_channel, inter_dim=0, heads=1, node_num=100):
        super(GNNLayer, self).__init__()
        self.gnn = GraphLayer(in_channel, out_channel, inter_dim=inter_dim, heads=heads, concat=False)
        self.bn = nn.BatchNorm1d(out_channel)
        self.relu = nn.ReLU()
        self.leaky_relu = nn.LeakyReLU()

    def forward(self, x, edge_index, embedding=None, node_num=0):
        out, (new_edge_index, att_weight) = self.gnn(x, edge_index, embedding, return_attention_weights=True)
        self.att_weight_1 = att_weight
        self.edge_index_1 = new_edge_index
        out = self.bn(out)
        return self.relu(out)

class OutLayer(nn.Module):
    def __init__(self, in_num, node_num, layer_num, inter_num = 512):
        super(OutLayer, self).__init__()
        layers = []
        for i in range(layer_num):
            # last layer, output shape:1
            if i==layer_num-1:
                if layer_num==1:
                    layers.append(nn.Linear(in_num , 1))
                else:
                    layers.append(nn.Linear(inter_num , 1))
            else:
                layer_in_num = in_num if i == 0 else inter_num
                layers.append(nn.Linear( layer_in_num, inter_num ))
                layers.append(nn.BatchNorm1d(inter_num))
                layers.append(nn.ReLU())

        self.mlp = nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for mod in self.mlp:
            if isinstance(mod, nn.BatchNorm1d):
                out = out.permute(0,2,1)
                out = mod(out)
                out = out.permute(0,2,1)
            else:
                out = mod(out)

        return out




class GDN(nn.Module):
    def __init__(self, edge_index_sets, node_num, dim=64, out_layer_inter_dim=256, input_dim=10, out_layer_num=1, topk=20):
        super(GDN, self).__init__()
        self.edge_index_sets = edge_index_sets
        device = 'cuda'
        edge_index = edge_index_sets[0]
        embed_dim = dim
        self.embedding = nn.Embedding(node_num, embed_dim)
        self.bn_outlayer_in = nn.BatchNorm1d(embed_dim)
        edge_set_num = len(edge_index_sets)
        
        current_gnn_layers = []
        for _ in range(edge_set_num):
            gnn_layer = GNNLayer(input_dim, dim, inter_dim=dim+embed_dim, heads=1)
            current_gnn_layers.append(gnn_layer)
        
        
        self.gnn_layers = nn.ModuleList(current_gnn_layers)

        self.node_embedding = None
        self.topk = topk
        self.learned_graph = None
        self.out_layer = OutLayer(dim*edge_set_num, node_num, out_layer_num, inter_num = out_layer_inter_dim)
        self.cache_edge_index_sets = [None] * edge_set_num
        self.cache_embed_index = None
        self.dp = nn.Dropout(0.2)
        self.init_params()
    
    def init_params(self):
        nn.init.kaiming_uniform_(self.embedding.weight, a=math.sqrt(5))

    def forward(self, data, org_edge_index):
        x = data.clone().detach()
        edge_index_sets = self.edge_index_sets
        device = data.device
        batch_num, node_num, all_feature = x.shape
        x = x.view(-1, all_feature).contiguous()
        gcn_outs = []
        for i, edge_index in enumerate(edge_index_sets):
            edge_num = edge_index.shape[1]
            cache_edge_index = self.cache_edge_index_sets[i]

            if cache_edge_index is None or cache_edge_index.shape[1] != edge_num*batch_num:
                self.cache_edge_index_sets[i] = get_batch_edge_index(edge_index, batch_num, node_num).to(device)
            
            batch_edge_index = self.cache_edge_index_sets[i]
            all_embeddings = self.embedding(torch.arange(node_num).to(device))
            weights_arr = all_embeddings.detach().clone()
            all_embeddings = all_embeddings.repeat(batch_num, 1)
            weights = weights_arr.view(node_num, -1)
            cos_ji_mat = torch.matmul(weights, weights.T)
            normed_mat = torch.matmul(weights.norm(dim=-1).view(-1,1), weights.norm(dim=-1).view(1,-1))
            cos_ji_mat = cos_ji_mat / normed_mat
            dim = weights.shape[-1]
            topk_num = self.topk
            topk_indices_ji = torch.topk(cos_ji_mat, topk_num, dim=-1)[1]
            self.learned_graph = topk_indices_ji
            gated_i = torch.arange(0, node_num).T.unsqueeze(1).repeat(1, topk_num).flatten().to(device).unsqueeze(0)
            gated_j = topk_indices_ji.flatten().unsqueeze(0)
            gated_edge_index = torch.cat((gated_j, gated_i), dim=0)
            batch_gated_edge_index = get_batch_edge_index(gated_edge_index, batch_num, node_num).to(device)
            gcn_out = self.gnn_layers[i](x, batch_gated_edge_index, node_num=node_num*batch_num, embedding=all_embeddings)
            gcn_outs.append(gcn_out)

        x = torch.cat(gcn_outs, dim=1)
        x = x.view(batch_num, node_num, -1)
        
        indexes = torch.arange(0,node_num).to(device)
        out = torch.mul(x, self.embedding(indexes))
        
        out = out.permute(0,2,1)
        out = F.relu(self.bn_outlayer_in(out))
        out = out.permute(0,2,1)

        out = self.dp(out)
        out = self.out_layer(out)
        out = out.view(-1, node_num)
        return out

In [6]:

def loss_func(y_pred, y_true):
    loss = F.mse_loss(y_pred, y_true, reduction='mean')
    return loss

def train(model, config={},  train_dataloader=None, val_dataloader=None, feature_map={}, test_dataloader=None, test_dataset=None, dataset_name='swat', train_dataset=None):
    seed = config['seed']
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=config['decay'])
    train_loss_list = []
    device = 'cuda'
    acu_loss = 0

    i = 0
    epoch = config['epoch']
    early_stop_win = 15

    model.train()

    dataloader = train_dataloader

    for i_epoch in range(epoch):
        acu_loss = 0
        model.train()
        for x, labels, attack_labels, edge_index in dataloader:
            _start = time.time()
            x, labels, edge_index = [item.float().to(device) for item in [x, labels, edge_index]]
            optimizer.zero_grad()
            out = model(x, edge_index).float().to(device)
            loss = loss_func(out, labels)            
            loss.backward()
            optimizer.step()
            train_loss_list.append(loss.item())
            acu_loss += loss.item()
            i += 1

        print('epoch ({} / {}) (Loss:{:.8f}, ACU_loss:{:.8f})'.format(i_epoch, epoch, acu_loss/len(dataloader), acu_loss), flush=True)

    return train_loss_list


### Main Function
<a id = "driver"></a>


In [7]:
class Main(object):
    def __init__(self, train_config, env_config, debug=False):
        self.train_config = train_config
        self.env_config = env_config
        self.datestr = None
        train_orig = pd.read_csv('C:/Users/harsh/OneDrive/Desktop/notebooks/train.csv', sep=',', index_col=0)
        test_orig = pd.read_csv('C:/Users/harsh/OneDrive/Desktop/notebooks/test.csv', sep=',', index_col=0)
        feature_list_path = 'C:/Users/harsh/OneDrive/Desktop/notebooks/list.txt'
       
        train, test = train_orig, test_orig

        if 'attack' in train.columns:
            train = train.drop(columns=['attack'])

        feature_map = get_feature_map(feature_list_path)
        fc_struc = get_fc_graph_struc(feature_list_path)
        self.device = 'cuda'

        fc_edge_index = build_loc_net(fc_struc, list(train.columns), feature_map=feature_map)
        fc_edge_index = torch.tensor(fc_edge_index, dtype = torch.long)

        self.feature_map = feature_map

        train_dataset_indata = construct_data(train, feature_map, labels=0)
        test_dataset_indata = construct_data(test, feature_map, labels=test.attack.tolist())


        cfg = {
            'slide_win': train_config['slide_win'],
            'slide_stride': train_config['slide_stride'],
        }

        train_dataset = TimeDataset(train_dataset_indata, fc_edge_index, mode='train', config=cfg)
        test_dataset = TimeDataset(test_dataset_indata, fc_edge_index, mode='test', config=cfg)


        train_dataloader, val_dataloader = self.get_loaders(train_dataset, train_config['seed'], train_config['batch'], val_ratio = train_config['val_ratio'])

        self.train_dataset = train_dataset
        self.test_dataset = test_dataset


        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = DataLoader(test_dataset, batch_size=train_config['batch'],
                            shuffle=False, num_workers=0)


        edge_index_sets = []
        edge_index_sets.append(fc_edge_index)

        self.model = GDN(edge_index_sets, len(feature_map), 
                dim=train_config['dim'], 
                input_dim=train_config['slide_win'],
                out_layer_num=train_config['out_layer_num'],
                out_layer_inter_dim=train_config['out_layer_inter_dim'],
                topk=train_config['topk']
            ).to(self.device)



    def run(self):
        self.train_log = train(self.model, 
                config = train_config,
                train_dataloader=self.train_dataloader,
                val_dataloader=self.val_dataloader, 
                feature_map=self.feature_map,
                test_dataloader=self.test_dataloader,
                test_dataset=self.test_dataset,
                train_dataset=self.train_dataset,
            )
        
    def get_loaders(self, train_dataset, seed, batch, val_ratio=0.1):
        dataset_len = int(len(train_dataset))
        train_use_len = int(dataset_len * (1 - val_ratio))
        val_use_len = int(dataset_len * val_ratio)
        val_start_index = random.randrange(train_use_len)
        indices = torch.arange(dataset_len)

        train_sub_indices = torch.cat([indices[:val_start_index], indices[val_start_index+val_use_len:]])
        train_subset = Subset(train_dataset, train_sub_indices)

        val_sub_indices = indices[val_start_index:val_start_index+val_use_len]
        val_subset = Subset(train_dataset, val_sub_indices)


        train_dataloader = DataLoader(train_subset, batch_size=batch,
                                shuffle=True)

        val_dataloader = DataLoader(val_subset, batch_size=batch,
                                shuffle=False)

        return train_dataloader, val_dataloader



if __name__ == "__main__":
    batch = 32
    epoch = 10
    slide_win = 5
    dim = 64
    slide_stride = 1
    device = 'cuda'
    random_seed = 5
    out_layer_num = 1
    out_layer_inter_dim = 128
    decay = 0
    val_ratio = 0.2
    topk = 5
    report = 'best'
    load_model_path = ''

    train_config = {
        'batch': batch,
        'epoch': epoch,
        'slide_win': slide_win,
        'dim': dim,
        'slide_stride': slide_stride,
        'seed': random_seed,
        'out_layer_num': out_layer_num,
        'out_layer_inter_dim': out_layer_inter_dim,
        'decay': decay,
        'val_ratio': val_ratio,
        'topk': topk,
    }

    env_config={
        'report': report,
        'device': device,
    }
    

    main = Main(train_config, env_config, debug=False)
    main.run()


  gated_i = torch.arange(0, node_num).T.unsqueeze(1).repeat(1, topk_num).flatten().to(device).unsqueeze(0)


epoch (0 / 10) (Loss:0.01976739, ACU_loss:37.71618425)
epoch (1 / 10) (Loss:0.00992056, ACU_loss:18.92842144)
epoch (2 / 10) (Loss:0.00906771, ACU_loss:17.30118730)
epoch (3 / 10) (Loss:0.00857509, ACU_loss:16.36126775)
epoch (4 / 10) (Loss:0.00865306, ACU_loss:16.51003857)
epoch (5 / 10) (Loss:0.00865328, ACU_loss:16.51045281)
epoch (6 / 10) (Loss:0.00864160, ACU_loss:16.48817449)
epoch (7 / 10) (Loss:0.00830218, ACU_loss:15.84056009)
epoch (8 / 10) (Loss:0.00815437, ACU_loss:15.55853944)
epoch (9 / 10) (Loss:0.00805597, ACU_loss:15.37078172)
