In [None]:
import torch

# Probe CUDA once and derive the target device from that single answer, so
# `use_gpu` and `device` always agree.
use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0") if use_gpu else torch.device("cpu")

# Report the runtime so the selected install branch can be checked by eye.
print(torch.__version__)
print(device)

# Initialize for CPU & GPU

In [None]:
# Install the PyTorch Geometric stack from wheels shipped as Kaggle input
# datasets. `use_gpu` is set by the device cell and picks which wheel set is used.
if use_gpu:
    # GPU runtime: --no-index disables the package index, so every package must
    # be found under the local --find-links directory.
    !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric --no-index --find-links=file:///kaggle/input/torch-geometric
#     !pip install torch_geometric_temporal  
else:
    # CPU runtime: install each wheel by explicit path.
    # NOTE(review): the file names suggest builds for PyTorch 2.0 / CPython 3.10
    # ("pt20cpu", "cp310") -- confirm they match the kernel before running.
    !pip install /kaggle/input/pytorch-geometric-cpu/torch_scatter-2.1.1pt20cpu-cp310-cp310-linux_x86_64.whl
    !pip install /kaggle/input/pytorch-geometric-cpu/torch_sparse-0.6.17pt20cpu-cp310-cp310-linux_x86_64.whl
    !pip install /kaggle/input/pytorch-geometric-cpu/torch_cluster-1.6.1pt20cpu-cp310-cp310-linux_x86_64.whl
    !pip install /kaggle/input/pytorch-geometric-cpu/torch_spline_conv-1.2.2pt20cpu-cp310-cp310-linux_x86_64.whl
    !pip install /kaggle/input/pytorch-geometric-cpu/torch_geometric-2.3.1-py3-none-any.whl
# Disabled alternative: install from the online PyG wheel index / GitHub
# instead of the bundled wheels (requires internet access).
#     !pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
#     !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
#     !pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
#     !pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
#     !pip install torch_spline_conv -f https://data.pyg.org/whl/torch-{torch.__version__}.html
#     !pip install git+https://github.com/pyg-team/pytorch_geometric.git
#     !pip install torch_geometric_temporal

# Import

In [None]:
import numpy as np
import pandas as pd
import os
import re
import json
import random
from scipy.special import perm
from itertools import combinations,chain
from typing import List, Union
from torch_geometric.data import Data
from sklearn.preprocessing import OneHotEncoder

import matplotlib.pyplot as plt
import pickle
import time
from torch import nn
from torch.nn import Linear as Lin
from torch.nn import ReLU, LeakyReLU
from torch.nn import Sequential as Seq
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.data import Batch
from torch import autograd
from torch_geometric.nn.models import InnerProductDecoder
from torch_geometric.utils import to_dense_adj

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, precision_recall_curve, roc_auc_score

from texttable import Texttable
from itertools import product

# DataLoader

In [None]:
# Every container alias below has the same shape: a list whose entries are
# NumPy arrays or None. Spell that type once and reuse it.
_OptionalArrayList = List[Union[np.ndarray, None]]

# Edge-side containers.
Edge_Flag = _OptionalArrayList
Edge_Index = _OptionalArrayList
Edge_Attr = _OptionalArrayList

# Node-side containers.
Node_Flag = _OptionalArrayList
Node_Index = _OptionalArrayList
Node_Attr = _OptionalArrayList



class GraphSignal(object):
    """Dynamic graph stored as flat arrays concatenated over snapshots.

    Nodes and edges of all snapshots are stored back to back. The ``*_flag``
    tensors hold the positions passed to ``torch.tensor_split`` to recover the
    per-snapshot pieces. Node attributes are static (one row per distinct node
    in ``raw_node_attr``) and are looked up per snapshot through ``node_index``.

    Typical call order, as used by the preprocessing cell of this notebook:
    ``encode_node_attr`` / ``encode_edge_attr`` -> ``extend_node_attr`` ->
    ``remove_init_stop`` -> ``annotation2y``.
    """

    def __init__(
        self,
        # The alias annotations are quoted so that defining the class does not
        # require the module-level aliases to have been evaluated first.
        edge_flag: "Edge_Flag",
        edge_index: "Edge_Index",
        edge_attr: "Edge_Attr",
        node_flag: "Node_Flag",
        node_index: "Node_Index",
        node_attr: "Node_Attr",
        ts_list: List,
        path: str,
    ):
        """Store the raw arrays of one recording.

        Args:
            edge_flag: per-snapshot edge split positions; the last entry is
                dropped (presumably cumulative end offsets whose last value is
                the total edge count -- TODO confirm against the data writer).
            edge_index: integer array with one row per edge; stored transposed
                (for the usual (E, 2) input this gives a (2, E) tensor).
            edge_attr: raw edge attributes, one row per edge.
            node_flag: per-snapshot node split positions, handled like
                ``edge_flag``.
            node_index: for every node occurrence, the row of ``node_attr``
                that describes it.
            node_attr: raw attributes, one row per distinct node.
            ts_list: one timestamp per snapshot; its length defines
                ``snapshot_count``.
            path: file the data was loaded from (kept for reference only).
        """
        self.raw_edge_flag = torch.LongTensor(edge_flag[:-1])
        self.raw_edge_index = torch.LongTensor(edge_index).T
        self.raw_edge_attr = edge_attr
        self.raw_node_attr = node_attr
        self.node_flag = torch.LongTensor(node_flag[:-1])
        self.node_index = torch.LongTensor(node_index)

        self.ts_list = ts_list

        self.path = path

        # Filled by encode_node_attr / encode_edge_attr. Initialised here so the
        # "not encoded yet" guard in extend_node_attr can actually detect it.
        self.node_attr_encoded = None
        self.edge_attr_encoded = None

        # Filled by extend_node_attr.
        self.node_attr = None
        self.edge_flag = None
        self.edge_index = None

        # Filled by annotation2y.
        self.y = None
        self.type = None

        # Iteration cursor used by __iter__ / __next__.
        self.t = 0

        self._set_snapshot_count()
        self._set_node_count()

    def _set_snapshot_count(self):
        """Refresh ``snapshot_count`` from the current timestamp list."""
        self.snapshot_count = len(self.ts_list)

    def _set_node_count(self):
        """Set ``node_count`` to the number of rows of the raw node attributes."""
        self.node_count = self.raw_node_attr.shape[0]

    def encode_edge_attr(self, enocder):
        """Encode the raw edge attributes with ``enocder.transform`` into a float tensor."""
        self.edge_attr_encoded = torch.FloatTensor(enocder.transform(self.raw_edge_attr))

    def encode_node_attr(self, enocder):
        """Encode the raw node attributes with ``enocder.transform`` into a float tensor."""
        self.node_attr_encoded = torch.FloatTensor(enocder.transform(self.raw_node_attr))

    def extend_node_attr(self):
        """Merge duplicate edges and append summed edge attributes to node attributes.

        For every snapshot, edges sharing the same (source, target) pair are
        collapsed into one edge whose attribute is the sum of the duplicates.
        The merged edge attributes are then added onto both endpoints and
        concatenated to the encoded node attributes.

        Sets ``node_attr``, ``edge_index`` and ``edge_flag``. ``edge_flag`` holds
        the cumulative merged-edge count after every snapshot, so its last entry
        equals the total number of merged edges.

        Raises:
            RuntimeError: if the attributes have not been encoded yet, or if a
                snapshot has a different number of edges and edge-attribute rows.
        """
        if self.edge_attr_encoded is None:
            raise RuntimeError("Edge Attr Need to be Encoded!")
        if self.node_attr_encoded is None:
            raise RuntimeError("Node Attr Need to be Encoded!")

        node_attr = self.node_attr_encoded.index_select(dim=0, index=self.node_index)
        edge_attr = self.edge_attr_encoded

        node_attr_split = torch.tensor_split(node_attr, self.node_flag)
        edge_index_split = torch.tensor_split(self.raw_edge_index, self.raw_edge_flag, dim=1)
        edge_attr_split = torch.tensor_split(edge_attr, self.raw_edge_flag)

        base = 0
        new_node_attr = []
        new_edge_flag = []
        new_edge_index = []

        for i_snapshot in range(self.snapshot_count):
            _node_attr = node_attr_split[i_snapshot]
            _edge_index = edge_index_split[i_snapshot]
            _edge_attr = edge_attr_split[i_snapshot]

            if _edge_index.shape[1] != _edge_attr.shape[0]:
                raise RuntimeError(
                    f"Snapshot {i_snapshot}: edge index shape {tuple(_edge_index.shape)} "
                    f"does not match edge attr shape {tuple(_edge_attr.shape)}"
                )

            # Edge-derived columns; they stay zero for nodes without any edge.
            _node_attr_extend = torch.zeros((_node_attr.shape[0], edge_attr.shape[1]))

            if _edge_index.shape[1] > 0:
                # Group edge positions by (source, target); dict order keeps the
                # first-occurrence order of the pairs.
                index_dict = {}
                for i_edge, index_pair in enumerate(_edge_index.T.tolist()):
                    index_dict.setdefault(tuple(index_pair), []).append(i_edge)

                _new_edge_index = torch.LongTensor(list(index_dict.keys())).T
                _new_edge_attr = torch.stack([
                    _edge_attr.index_select(0, torch.LongTensor(rows)).sum(dim=0)
                    for rows in index_dict.values()
                ])

                base += _new_edge_index.shape[1]
                new_edge_index.append(_new_edge_index)

                # Add each merged edge attribute to its source and its target node.
                _node_attr_extend.index_add_(0, _new_edge_index[0], _new_edge_attr)
                _node_attr_extend.index_add_(0, _new_edge_index[1], _new_edge_attr)

            # Appended for every snapshot, including edge-less ones, so that
            # node_attr stays aligned with node_flag / node_index.
            new_node_attr.append(torch.cat([_node_attr, _node_attr_extend], dim=1))
            new_edge_flag.append(base)

        self.node_attr = torch.cat(new_node_attr, dim=0)
        self.edge_flag = torch.LongTensor(new_edge_flag)
        if new_edge_index:
            self.edge_index = torch.cat(new_edge_index, dim=1)
        else:
            # No snapshot has any edge: keep a well-formed empty (2, 0) index.
            self.edge_index = torch.zeros((2, 0), dtype=torch.long)

    def remove_init_stop(self, threshold, period,):
        """Trim the start-up and shut-down phases of the recording.

        The first and the last snapshot whose node count exceeds ``threshold``
        delimit the active phase (the very first and very last snapshot are not
        considered, because node counts come from ``torch.diff(node_flag)``).
        Snapshots ``i_init + 1 + period`` up to but excluding ``i_stop - period``
        are kept; all per-snapshot arrays, flags and timestamps are re-based on
        that window. Requires ``extend_node_attr`` to have been called.

        If no snapshot exceeds the threshold, or the window is empty, the signal
        is reduced to zero snapshots instead of failing or wrapping around with
        negative slice bounds.

        Args:
            threshold: node count a snapshot must exceed to count as active.
            period: number of extra snapshots discarded at each end (expected >= 0).
        """
        node_index_split = torch.tensor_split(self.node_index, self.node_flag)
        node_attr_split = torch.tensor_split(self.node_attr, self.node_flag)
        edge_index_split = torch.tensor_split(self.edge_index, self.edge_flag, dim=1)

        raw_edge_index_split = torch.tensor_split(self.raw_edge_index, self.raw_edge_flag, dim=1)
        # np.split gets plain Python ints rather than a torch tensor.
        raw_edge_attr_split = np.split(self.raw_edge_attr, self.raw_edge_flag.tolist())

        # node_counts[i] is the node count of snapshot i + 1.
        node_counts = torch.diff(self.node_flag).tolist()
        i_init = next(
            (i + 1 for i, count in enumerate(node_counts) if count > threshold),
            None,
        )
        i_stop = next(
            (len(node_counts) - i for i, count in enumerate(reversed(node_counts)) if count > threshold),
            None,
        )

        start = None if i_init is None else i_init + 1 + period
        stop = None if i_stop is None else i_stop - period

        if start is None or stop is None or stop <= start:
            # Nothing to keep: produce an empty but well-shaped signal.
            new_node_attr = self.node_attr[:0]
            new_node_index = self.node_index[:0]
            new_edge_index = self.edge_index[:, :0]
            new_raw_edge_index = self.raw_edge_index[:, :0]
            new_raw_edge_attr = self.raw_edge_attr[:0]
            new_node_flag = self.node_flag[:0]
            new_edge_flag = self.edge_flag[:0]
            new_raw_edge_flag = self.raw_edge_flag[:0]
            new_ts_list = self.ts_list[:0]
        else:
            kept = slice(start, stop)
            new_node_attr = torch.cat(node_attr_split[kept], dim=0)
            new_node_index = torch.cat(node_index_split[kept], dim=0)
            new_edge_index = torch.cat(edge_index_split[kept], dim=1)
            new_raw_edge_index = torch.cat(raw_edge_index_split[kept], dim=1)
            new_raw_edge_attr = np.concatenate(raw_edge_attr_split[kept])

            # Split positions between the kept snapshots, re-based so that the
            # first kept snapshot starts at offset 0.
            kept_flags = slice(start, stop - 1)
            origin = start - 1
            new_node_flag = self.node_flag[kept_flags] - self.node_flag[origin]
            new_edge_flag = self.edge_flag[kept_flags] - self.edge_flag[origin]
            new_raw_edge_flag = self.raw_edge_flag[kept_flags] - self.raw_edge_flag[origin]
            new_ts_list = self.ts_list[kept]

        self.node_attr = new_node_attr
        self.node_index = new_node_index
        self.edge_index = new_edge_index
        self.raw_edge_attr = new_raw_edge_attr
        self.raw_edge_index = new_raw_edge_index

        self.node_flag = new_node_flag
        self.edge_flag = new_edge_flag
        self.raw_edge_flag = new_raw_edge_flag

        self.ts_list = new_ts_list

        self._set_snapshot_count()

    def annotation2y(self, annotation, interval, overlap, offset= -30):
        """Derive per-snapshot binary labels from one annotation.

        A snapshot with timestamp ``ts`` is labelled 1 when the shifted event
        time ``float(annotation[1]) + offset`` lies in
        ``(ts, ts + interval - overlap]``; every other snapshot is labelled 0.

        Args:
            annotation: sequence whose first item is stored as ``self.type`` and
                whose second item is the event time (anything ``float`` accepts).
            interval: snapshot window length, in the units of ``ts_list``.
            overlap: overlap between consecutive windows, in the same units.
            offset: shift added to the annotated event time.
        """
        event_time = float(annotation[1]) + offset
        y = torch.zeros(self.snapshot_count, dtype=torch.long)
        for i_ts, ts in enumerate(self.ts_list):
            if ts < event_time <= ts + interval - overlap:
                y[i_ts] = 1
        self.y = y
        self.type = annotation[0]

    def to(self,device):
        """Move ``node_attr``, ``node_index`` and ``edge_index`` to ``device`` (in place)."""
        self.node_attr = self.node_attr.to(device)
        self.node_index = self.node_index.to(device)
        self.edge_index = self.edge_index.to(device)

    def get_adj_list(self):
        """Return one dense adjacency matrix per split of ``edge_index``.

        Entries are clamped to [0, 1]. The matrix size is whatever
        ``to_dense_adj`` infers from the edge indices of that split.
        """
        edge_index_split = torch.tensor_split(self.edge_index, self.edge_flag, dim=1)
        adj_list = [
            torch.clamp(to_dense_adj(_edge_index)[0], min=0, max=1)
            for _edge_index in edge_index_split
        ]
        return adj_list

    def __getitem__(self, time_index: int):
        """Per-snapshot access is not implemented.

        Raises:
            NotImplementedError: always. It is a ``RuntimeError`` subclass, so
                code that caught the error raised previously keeps working.
        """
        raise NotImplementedError("GraphSignal does not support per-snapshot indexing")

    def __next__(self):
        """Return the next snapshot; currently fails because ``__getitem__`` is not implemented."""
        if self.t < self.snapshot_count:
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        else:
            self.t = 0
            raise StopIteration

    def __iter__(self):
        """Reset the iteration cursor and return ``self``."""
        self.t = 0
        return self

    def __len__(self):
        """Number of snapshots currently held."""
        return self.snapshot_count


class GraphDatasetLoader(object):
    """Load one recording from disk and expose it as a ``GraphSignal``."""

    def __init__(self,input_path=""):
        """Remember ``input_path`` and read the file immediately.

        Args:
            input_path: path handed to ``np.load``; the notebook passes ``.npz``
                archives.
        """
        self.input_path = input_path
        self._read_data()

    def _read_data(self):
        """Read the file at ``input_path`` into ``self._dataset``.

        For an ``.npz`` archive every member array is read eagerly into a plain
        dict and the archive is closed right away, so no file handle stays open
        for the lifetime of the loader. Any other object returned by ``np.load``
        is stored as is.
        """
        loaded = np.load(self.input_path)
        if hasattr(loaded, "files"):
            # NpzFile: lazy and backed by an open file -- materialise and close.
            with loaded:
                self._dataset = {name: loaded[name] for name in loaded.files}
        else:
            self._dataset = loaded

    def get_dataset(self): # -> DynamicGraphTemporalSignal:
        """Build a ``GraphSignal`` from the loaded arrays.

        Expects the keys ``edge_flag``, ``edge_index``, ``edge_attr``,
        ``node_flag``, ``node_index``, ``node_attr`` and ``timestamp``.
        """
        dataset = GraphSignal(
            edge_flag = self._dataset['edge_flag'],
            edge_index = self._dataset['edge_index'],
            edge_attr = self._dataset['edge_attr'],
            node_flag = self._dataset['node_flag'],
            node_index = self._dataset['node_index'],
            node_attr = self._dataset['node_attr'],
            ts_list = self._dataset['timestamp'],
            path = self.input_path
        )
        return dataset

# Read Data

In [None]:
# Root of the Kaggle input dataset and the experiment folders read from it.
data_dir_0 = '/kaggle/input/dissertation-data'
data_dir_1_list = [
    '2021-09-24-umbrella-experiment-64run-fran',
    '2021-09-27-RaspVM-experiment-64run-env1',
    '2021-09-29-RaspVM-experiment-64run-env1'
]


signals = []
annotation = []

for data_dir_1 in data_dir_1_list:
    experiment_dir = os.path.join(data_dir_0, data_dir_1)

    # annotated.json maps each run folder name to its annotation.
    with open(os.path.join(experiment_dir, "annotated.json")) as f:
        annotated_dict = json.load(f)

    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order, and therefore the seeded train/val split below, reproducible.
    for data_dir_2 in sorted(os.listdir(experiment_dir)):
        if data_dir_2 == "annotated.json":
            continue
        run_dir = os.path.join(experiment_dir, data_dir_2)
        if not os.path.isdir(run_dir):
            # Stray files next to the run folders are not samples.
            continue

        # Match on the real extension; the previous pattern ".*.npz" also
        # accepted any name that merely contained "npz".
        graph_files = [name for name in sorted(os.listdir(run_dir)) if name.endswith(".npz")]

        if len(graph_files) > 1:
            raise RuntimeError(f"Multiple Graph Files! ({run_dir})")
        if len(graph_files) == 0:
            print("Not Found Graph File!")
            continue

        dataloader = GraphDatasetLoader(os.path.join(run_dir, graph_files[0]))
        signal = dataloader.get_dataset()
        signals.append(signal)
        annotation.append(annotated_dict[data_dir_2])

# 50/50 split with a fixed seed; signals and annotations are split in lock-step.
signals_train, signals_val, annotation_train, annotation_val = train_test_split(signals, annotation, test_size=0.5, random_state=1365)

In [None]:
# Window parameters forwarded to GraphSignal.annotation2y (same units as the
# signal timestamps).
_interval = 60
_overlap = 30

# Node count of every training snapshot that torch.diff(node_flag) exposes.
node_num_list = []
for signal in signals_train:
    node_num_list += torch.diff(signal.node_flag).tolist()

# A snapshot counts as "active" once it holds more than half the median node count.
threshold = np.median(node_num_list)/2
period = 3

print(f"Threshold = {threshold} Period = {period}")

# Both encoders are fitted on the training split only; categories that appear
# only in validation data are encoded as all-zero rows (handle_unknown='ignore').
node_attr_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
node_attr_encoder = node_attr_encoder.fit(np.concatenate([sample.raw_node_attr for sample in signals_train]))

edge_attr_encoder = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
edge_attr_encoder = edge_attr_encoder.fit(np.concatenate([sample.raw_edge_attr for sample in signals_train]))


def _prepare_signals(split_signals, split_annotations):
    """Encode, extend, trim, squash and label every signal of one split, in place.

    The loop variables are local, so the module-level ``annotation`` list built
    by the loading cell is no longer overwritten by this cell.
    """
    for split_signal, split_annotation in zip(split_signals, split_annotations):
        split_signal.encode_node_attr(node_attr_encoder)
        split_signal.encode_edge_attr(edge_attr_encoder)
        split_signal.extend_node_attr()
        split_signal.remove_init_stop(threshold, period)
        # torch.tanh replaces the deprecated F.tanh alias; the values are identical.
        split_signal.node_attr = torch.tanh(split_signal.node_attr)
        split_signal.annotation2y(split_annotation, _interval, _overlap)


# NOTE: this mutates the signals, so the cell must not be run twice on the
# same objects (re-run the loading cell first).
_prepare_signals(signals_train, annotation_train)
_prepare_signals(signals_val, annotation_val)

# The signals are left on the CPU here; call signal.to(device) on each of them
# if they are needed on the GPU.

IN_CHANNELS = signals_train[0].node_attr.shape[1]

In [None]:
# Distribution of the number of snapshots per signal, one histogram per split.
snapshot_splits = {"train": signals_train, "val": signals_val}

# train_test_split in the loading cell only creates the train and val splits;
# `signals_test` exists only if it was loaded separately (e.g. from a pickle).
# Plot it when present instead of failing with a NameError on a fresh kernel.
if "signals_test" in globals():
    snapshot_splits["test"] = globals()["signals_test"]

for split_name, split_signals in snapshot_splits.items():
    plt.hist([signal.snapshot_count for signal in split_signals], range=(0,50), bins=50)
    plt.title(f"Snapshots per signal ({split_name})")
    plt.xlabel("snapshot count")
    plt.ylabel("signals")
    plt.show()

In [None]:
# Spot check: annotation type and per-snapshot node counts of the first 30
# training signals.
for signal in signals_train[:30]:
    print(signal.type, torch.diff(signal.node_flag).tolist())

In [None]:
# Signals with fewer snapshots than this are discarded.
MIN_LEN = 10


def _drop_short_signals(split_signals, min_len):
    """Remove, in place, every signal with fewer than ``min_len`` snapshots.

    Slice assignment keeps the list object itself, so other references to the
    same list see the filtered content, exactly like the pop-based loop did.
    """
    split_signals[:] = [s for s in split_signals if s.snapshot_count >= min_len]


# NOTE: annotation_train / annotation_val are not filtered here; the labels
# already live on the signals (signal.y, signal.type).
for split_name, split_signals in (("train", signals_train), ("val", signals_val)):
    _drop_short_signals(split_signals, MIN_LEN)
    plt.hist([signal.snapshot_count for signal in split_signals], range=(0,30), bins=30)
    plt.title(f"Snapshots per signal after filtering ({split_name})")
    plt.xlabel("snapshot count")
    plt.ylabel("signals")
    plt.show()

# Persist the prepared splits so later sessions can skip the preprocessing.
with open('signals_train.pkl','wb') as f:
    pickle.dump(signals_train, f)
with open('signals_val.pkl','wb') as f:
    pickle.dump(signals_val, f)


In [None]:
# import pickle

# with open("/kaggle/input/dissertation-data/signals_train.pkl", "rb") as f:
#     signals_train = pickle.load(f)
# with open("/kaggle/input/dissertation-data/signals_val.pkl", "rb") as f:
#     signals_val = pickle.load(f)
# with open("/kaggle/input/dissertation-data/signals_test.pkl", "rb") as f:
#     signals_test = pickle.load(f)