In [1]:
import torch

# Initialize for CPU

In [2]:
!pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install git+https://github.com/pyg-team/pytorch_geometric.git
!pip install torch_geometric_temporal

[0mLooking in links: https://data.pyg.org/whl/torch-2.0.0+cpu.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_scatter-2.1.1%2Bpt20cpu-cp310-cp310-linux_x86_64.whl (504 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m504.1/504.1 kB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.1+pt20cpu
[0mLooking in links: https://data.pyg.org/whl/torch-2.0.0+cpu.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_sparse-0.6.17%2Bpt20cpu-cp310-cp310-linux_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m26.5 MB/s[0m eta [36m0:00:00[0m00:01[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.17+pt20cpu
[0mLooking in links: https://data.pyg.org/whl/torch-2.0.0+cpu.html
Collecting torch-cluster
  Downloading htt

# Initialize for GPU

In [3]:
# !pip uninstall torch-scatter torch-sparse torch-geometric torch-cluster  --y
# !pip install torch-scatter -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install torch-cluster -f https://data.pyg.org/whl/torch-{torch.__version__}.html
# !pip install git+https://github.com/pyg-team/pytorch_geometric.git
# !pip install torch_geometric_temporal

In [4]:
import numpy as np
import pandas as pd
import torch
from scipy.special import perm
from itertools import combinations,chain
from typing import List, Union
from torch_geometric.data import Data

# DataLoader

In [5]:
Edge_Flag = List[Union[np.ndarray, None]]
Edge_Indices = List[Union[np.ndarray, None]]
Edge_Attr = List[Union[np.ndarray, None]]

Node_Flag = List[Union[np.ndarray, None]]
Node_Indices = List[Union[np.ndarray, None]]
Node_Attr = List[Union[np.ndarray, None]]
Additional_Attrs = List[np.ndarray]



class GraphSignal(object):
    # dynamic node static node attr dynamic edge and edge attr
    def __init__(
        self,
        edge_flag: Edge_Flag,
        edge_indices: Edge_Indices,
        edge_attr: Edge_Attr,
        node_flag: Node_Flag,
        node_indices: Node_Indices,
        node_attr: Node_Attr,
    ):
        
        self.edge_flag = edge_flag 
        self.edge_indices = edge_indices
        self.edge_attr = edge_attr
        self.node_flag = node_flag
        self.node_indices = node_indices
        self.node_attr = node_attr
        self._set_snapshot_count()

    def _set_snapshot_count(self):
        self.snapshot_count = len(self.edge_flag)

    def _get_edge_index(self, time_index: int):
        if time_index == 0:
            _start = 0
        else:
            _start = self.edge_flag[time_index-1]
        _end = self.edge_flag[time_index]
        _edge_index = self.edge_indices[_start:_end]
        return torch.LongTensor(np.array(_edge_index).T)

    # def _get_edge_weight(self, time_index: int):
    #     if self.edge_weights[time_index] is None:
    #         return self.edge_weights[time_index]
    #     else:
    #         return torch.FloatTensor(self.edge_weights[time_index])

    def _get_edge_attr(self, time_index: int):
        if time_index == 0:
            _start = 0
        else:
            _start = self.edge_flag[time_index-1]
        _end = self.edge_flag[time_index]
        _edge_attr = self.edge_attr[_start:_end]
        return torch.FloatTensor(np.array(_edge_attr))
    
    def _get_node_index_attr(self, time_index: int):
        if time_index == 0:
            _start = 0
        else:
            _start = self.node_flag[time_index-1]
        _end = self.node_flag[time_index]
        _node_index = self.node_indices[_start:_end]
        _node_attr = self.node_attr[_node_index]
        return torch.LongTensor(np.array(_node_index)),torch.FloatTensor(np.array(_node_attr))


    def __getitem__(self, time_index: int):
        edge_index = self._get_edge_index(time_index)
        edge_attr = self._get_edge_attr(time_index)
        node_index,node_attr = self._get_node_index_attr(time_index)

        snapshot = Data(
            edge_index=edge_index,
            edge_attr=edge_attr,
            node_index=node_index,
            node_attr=node_attr,
        )
        return snapshot

    def __next__(self):
        if self.t < self.snapshot_count:
            snapshot = self[self.t]
            self.t = self.t + 1
            return snapshot
        else:
            self.t = 0
            raise StopIteration

    def __iter__(self):
        self.t = 0
        return self
    
    def __len__(self):
        return self.snapshot_count


class GraphDatasetLoader(object):
    def __init__(self,input_path=""):
        self.input_path = input_path
        self._read_data()
    
    def _read_data(self):
        self._dataset = np.load(self.input_path)

    def get_dataset(self): # -> DynamicGraphTemporalSignal:
        dataset = GraphSignal(
            edge_flag = self._dataset['edge_flag_array'],
            edge_indices = self._dataset['edge_index_array'],
            edge_attr = self._dataset['edge_attr_array'],
            node_flag = self._dataset['node_flag_array'],
            node_indices = self._dataset['node_index_array'],
            node_attr = self._dataset['node_attr_array'],
        )
        return dataset

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_gpu = torch.cuda.is_available()
dataloader = GraphDatasetLoader("/kaggle/input/dissertation-data/test_60_30.npz")
train_set = dataloader.get_dataset()

# Model

In [7]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import pickle
import time

import torch
from torch import nn
from torch.nn import Linear as Lin
from torch.nn import ReLU
from torch.nn import Sequential as Seq
import torch.nn.functional as F
from torch_geometric.nn import NNConv
from torch_geometric_temporal import GConvGRU
from torch import autograd

In [8]:
class MultiNNConv(torch.nn.Module):
    def __init__(self, in_channels, out_channels, edge_channels, gcn_hidden_nums, edge_hidden_nums, lin_hidden_nums):
        super().__init__()
        
        def _create_edge_nn(edge_out_channels):
            edge_nn = Seq()
            pre_edge_h_num = edge_channels
            for edge_h_num in edge_hidden_nums:
                edge_nn.append(Lin(pre_edge_h_num,edge_h_num))
                edge_nn.append(ReLU())
                pre_edge_h_num = edge_h_num
            edge_nn.append(Lin(pre_edge_h_num,edge_out_channels))
            return edge_nn
        
        self.gcn_layers = nn.ModuleList()
        pre_h_num = in_channels
        for h_num in gcn_hidden_nums:
            edge_nn = _create_edge_nn(pre_h_num*h_num)
            self.gcn_layers.append(NNConv(pre_h_num, h_num, edge_nn, aggr='mean'))
            pre_h_num = h_num

        self.lin_net = Seq()
        for h_num in lin_hidden_nums[:-1]:
            self.lin_net.append(Lin(pre_h_num,h_num))
            pre_h_num = h_num
        self.lin_net.append(ReLU())
        self.lin_net.append(Lin(pre_h_num,out_channels))

    def forward(self, x, edge_index, edge_attr):
        out = x
        for conv in self.gcn_layers:
            out = conv(
                x=out,
                edge_index=edge_index,
                edge_attr=edge_attr,
            )
        out = self.lin_net(out)
        return out



class NNConvGRU(torch.nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        edge_channels: int,
        gcn_hidden_nums: List,
        edge_hidden_nums: List,
        lin_hidden_nums: List,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super(NNConvGRU, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_channels = edge_channels
        self.gcn_hidden_nums = gcn_hidden_nums
        self.edge_hidden_nums = edge_hidden_nums
        self.lin_hidden_nums = lin_hidden_nums
            
        self.normalization = normalization
        self.bias = bias
        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):

        self.conv_x_z = MultiNNConv(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            edge_channels = self.edge_channels,
            gcn_hidden_nums = self.gcn_hidden_nums,
            edge_hidden_nums = self.edge_hidden_nums,
            lin_hidden_nums = self.lin_hidden_nums,
        )

        self.conv_h_z = MultiNNConv(
            in_channels = self.out_channels,
            out_channels = self.out_channels,
            edge_channels = self.edge_channels,
            gcn_hidden_nums = self.gcn_hidden_nums,
            edge_hidden_nums = self.edge_hidden_nums,
            lin_hidden_nums = self.lin_hidden_nums,
        )
        
    def _create_reset_gate_parameters_and_layers(self):

        self.conv_x_r = MultiNNConv(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            edge_channels = self.edge_channels,
            gcn_hidden_nums = self.gcn_hidden_nums,
            edge_hidden_nums = self.edge_hidden_nums,
            lin_hidden_nums = self.lin_hidden_nums,
        )

        self.conv_h_r = MultiNNConv(
            in_channels = self.out_channels,
            out_channels = self.out_channels,
            edge_channels = self.edge_channels,
            gcn_hidden_nums = self.gcn_hidden_nums,
            edge_hidden_nums = self.edge_hidden_nums,
            lin_hidden_nums = self.lin_hidden_nums,
        )

    def _create_candidate_state_parameters_and_layers(self):

        self.conv_x_h = MultiNNConv(
            in_channels = self.in_channels,
            out_channels = self.out_channels,
            edge_channels = self.edge_channels,
            gcn_hidden_nums = self.gcn_hidden_nums,
            edge_hidden_nums = self.edge_hidden_nums,
            lin_hidden_nums = self.lin_hidden_nums,
        )

        self.conv_h_h = MultiNNConv(
            in_channels = self.out_channels,
            out_channels = self.out_channels,
            edge_channels = self.edge_channels,
            gcn_hidden_nums = self.gcn_hidden_nums,
            edge_hidden_nums = self.edge_hidden_nums,
            lin_hidden_nums = self.lin_hidden_nums,
        )

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_attr, H):
        Z = self.conv_x_z(X, edge_index, edge_attr)
        Z = Z + self.conv_h_z(H, edge_index, edge_attr)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_attr, H):
        R = self.conv_x_r(X, edge_index, edge_attr)
        R = R + self.conv_h_r(H, edge_index, edge_attr)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_attr, H, R):
        H_tilde = self.conv_x_h(X, edge_index, edge_attr)
        H_tilde = H_tilde + self.conv_h_h(H * R, edge_index, edge_attr)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
            self,
            X: torch.FloatTensor,
            edge_index: torch.LongTensor,
            edge_attr: torch.FloatTensor,
            H: torch.FloatTensor = None,
        ) -> torch.FloatTensor:
            H = self._set_hidden_state(X, H)
            Z = self._calculate_update_gate(X, edge_index, edge_attr, H)
            R = self._calculate_reset_gate(X, edge_index, edge_attr, H)
            H_tilde = self._calculate_candidate_state(X, edge_index, edge_attr, H, R)
            H = self._calculate_hidden_state(Z, H, H_tilde)
            return H

In [9]:
class TGAE(torch.nn.Module): # Not Heterogeneous
    def __init__(
        self, in_channels, out_channels, edge_channels, 
        embedding_hidden_nums, gnn_out_channels, deciding_hidden_nums,
        gru_gcn_hidden_nums, gru_edge_hidden_nums, gru_lin_hidden_nums):
        super(TGAE, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_channels = edge_channels
        self.embedding_hidden_nums = embedding_hidden_nums
        self.gnn_out_channels = gnn_out_channels
        self.deciding_hidden_nums = deciding_hidden_nums
        self.gru_gcn_hidden_nums = gru_gcn_hidden_nums
        self.gru_edge_hidden_nums = gru_edge_hidden_nums
        self.gru_lin_hidden_nums = gru_lin_hidden_nums
        
        # Encoder
        layers = []
        pre_h_num = in_channels
        for h_num in embedding_hidden_nums[:-1]:
#             layers.append(torch.nn.Dropout(p=0.2))
            layers.append(Lin(pre_h_num,h_num))
            layers.append(torch.nn.LeakyReLU())
            pre_h_num = h_num
        layers.append(Lin(pre_h_num,embedding_hidden_nums[-1]))
        self.encoder_embedding_net = Seq(*layers)
        
        self.encoder_gru = NNConvGRU(
            in_channels=embedding_hidden_nums[-1],
            out_channels=gnn_out_channels,
            edge_channels=edge_channels,
            gcn_hidden_nums=gru_gcn_hidden_nums,
            edge_hidden_nums=gru_edge_hidden_nums,
            lin_hidden_nums=gru_lin_hidden_nums,
        )

        layers = []
        pre_h_num = gnn_out_channels
        for h_num in deciding_hidden_nums:
#             layers.append(torch.nn.Dropout(p=0.2))
            layers.append(Lin(pre_h_num,h_num))
            layers.append(torch.nn.LeakyReLU())
            pre_h_num = h_num
        layers.append(Lin(pre_h_num,out_channels))
        self.encoder_deciding_net = Seq(*layers)
        
        # Decoder
        layers = []
        pre_h_num = out_channels
        for h_num in deciding_hidden_nums[::-1]:
            layers.append(torch.nn.Dropout(p=0.2))
            layers.append(Lin(pre_h_num,h_num))
            layers.append(torch.nn.LeakyReLU())
            pre_h_num = h_num
        layers.append(Lin(pre_h_num,gnn_out_channels))
        self.decoder_deciding_net = Seq(*layers)
        
        self.decoder_gru = NNConvGRU(
            in_channels=gnn_out_channels,
            out_channels=embedding_hidden_nums[-1],
            edge_channels=edge_channels,
            gcn_hidden_nums=gru_gcn_hidden_nums,
            edge_hidden_nums=gru_edge_hidden_nums,
            lin_hidden_nums=gru_lin_hidden_nums,
        )
        
        layers = []
        pre_h_num = embedding_hidden_nums[-1]
        for h_num in embedding_hidden_nums[:-1][::-1]:
            layers.append(torch.nn.Dropout(p=0.2))
            layers.append(Lin(pre_h_num,h_num))
            layers.append(torch.nn.LeakyReLU())
            pre_h_num = h_num
        layers.append(Lin(pre_h_num,in_channels))
        self.decoder_embedding_net = Seq(*layers)
        

    def forward(self, x, edge_index, edge_attr, h_encoder=None, h_decoder=None):
        """
        x = Node features for T time steps
        edge_index = Graph edge indices
        """
        # Encoder
        out = self.encoder_embedding_net(x)
        
        # GNN layer
        h_encoder = self.encoder_gru(out, edge_index, edge_attr, h_encoder) 
        
        out = self.encoder_deciding_net(h_encoder)
        
        out = self.decoder_deciding_net(out)
        
        h_decoder = self.decoder_gru(out, edge_index, edge_attr, h_decoder)
        
        out = self.decoder_embedding_net(h_decoder)

        return out, h_encoder, h_decoder

In [10]:
model = TGAE(
    in_channels=5, 
    out_channels=5, 
    edge_channels=57, 
    embedding_hidden_nums=[4,4],
    gnn_out_channels=8,
    deciding_hidden_nums=[4,4],
    gru_gcn_hidden_nums=[16,16],
    gru_edge_hidden_nums=[32],
    gru_lin_hidden_nums=[64,64],
)

loss_f = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 2e-4, weight_decay=1e-5)
print(model)

TGAE(
  (encoder_embedding_net): Sequential(
    (0): Linear(in_features=5, out_features=4, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=4, out_features=4, bias=True)
  )
  (encoder_gru): NNConvGRU(
    (conv_x_z): MultiNNConv(
      (gcn_layers): ModuleList(
        (0): NNConv(4, 16, aggr=mean, nn=Sequential(
          (0): Linear(in_features=57, out_features=32, bias=True)
          (1): ReLU()
          (2): Linear(in_features=32, out_features=64, bias=True)
        ))
        (1): NNConv(16, 16, aggr=mean, nn=Sequential(
          (0): Linear(in_features=57, out_features=32, bias=True)
          (1): ReLU()
          (2): Linear(in_features=32, out_features=256, bias=True)
        ))
      )
      (lin_net): Sequential(
        (0): Linear(in_features=16, out_features=64, bias=True)
        (1): ReLU()
        (2): Linear(in_features=64, out_features=8, bias=True)
      )
    )
    (conv_h_z): MultiNNConv(
      (gcn_layers): ModuleList(
        (

# Training

In [11]:
def create_hidden_global(num_nodes, out_channels):
    hidden_global = np.zeros([num_nodes,out_channels])
    return hidden_global

def create_cell_global(num_nodes, out_channels):
    cell_global = torch.zeros(num_nodes,out_channels)
    return cell_global

def select_hidden_local(hidden_global, index):
    h = hidden_global[index]
    return h

def select_cell_local(cell_global, index):
    c = cell_global[index]
    return c

# TODO: Aggregation of hidden and cell
def update_hidden_gobal(hidden_global, h, index):
    hidden_global[index] = h.detach().cpu().numpy()
    # for key,value in mapping.items():
    #     hidden_global[value] = h[key] 

def update_cell_gobal(cell_global, c, index):
    cell_global[index] = c.detach().cpu().numpy()
    # for key,value in mapping.items():
    #     cell_global[value] = c[key] 

In [12]:
print(train_set.node_attr.shape)
for i,snapshot in enumerate(train_set,start=1):
    print(snapshot.node_attr.shape[0])
    print(snapshot.edge_index)


(4839, 5)
4
tensor([[1, 2, 3],
        [0, 0, 0]])
4
tensor([[1, 2, 3],
        [0, 0, 0]])
163
tensor([[ 1,  2,  3,  ...,  1,  2,  3],
        [ 0,  0,  0,  ..., 12, 12, 12]])
70
tensor([[ 1,  2,  3,  ...,  2, 11, 10],
        [ 0,  0,  0,  ..., 10, 10, 51]])
68
tensor([[ 1,  2,  3,  ...,  5,  2, 12],
        [ 0,  0,  0,  ..., 11, 11, 11]])
72
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,  ..., 0, 0, 0]])
187
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,  ..., 0, 0, 0]])
65
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,  ..., 0, 0, 0]])
75
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,  ..., 0, 0, 0]])
72
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,  ..., 0, 0, 0]])
66
tensor([[ 1,  2,  3,  ..., 65, 65, 65],
        [ 0,  0,  0,  ..., 19, 20, 21]])
79
tensor([[ 1,  2,  3,  ..., 28,  2, 48],
        [ 0,  0,  0,  ..., 49, 49, 49]])
85
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,  ..., 0, 0, 0]])
72
tensor([[1, 2, 3,  ..., 1, 2, 3],
        [0, 0, 0,

In [13]:
# train_loop
def train_loop(train_set, num_nodes, state_channels, model, loss_fn, optimizer, device):

    hidden_encoder_global = create_hidden_global(num_nodes=num_nodes,out_channels=state_channels)
    hidden_decoder_global = create_hidden_global(num_nodes=num_nodes,out_channels=model.embedding_hidden_nums[-1])
    # cell_global = create_cell_global(num_nodes=len(global_nodes),out_channels=256)
    
    train_losses = []
    model.train()
    for i,snapshot in enumerate(train_set,start=1):
        if snapshot.node_attr.shape[0] == 0:
            # print("snapshot_{} has no data...".format(i))
            continue
        
        node_attr = snapshot.node_attr
        node_index = snapshot.node_index
        edge_attr = snapshot.edge_attr
        edge_index = snapshot.edge_index

        
        _node_attr = node_attr.to(device)
        _edge_attr = edge_attr.to(device)
        _edge_index = edge_index.to(device)
        
        pre_h_encoder = torch.tensor(select_hidden_local(hidden_encoder_global, node_index),dtype=torch.float32).to(device)
        pre_h_decoder = torch.tensor(select_hidden_local(hidden_decoder_global, node_index),dtype=torch.float32).to(device)
        
        print(_node_attr.type())
        # Compute prediction and loss
        outs, h_encoder, h_decoder = model(_node_attr,_edge_index,_edge_attr)

        update_hidden_gobal(hidden_encoder_global, h_encoder, node_index)
        update_hidden_gobal(hidden_decoder_global, h_decoder, node_index)

        train_loss = 0
        for i in range(node_attr.shape[0]):
#             train_loss += torch.sqrt(loss_f(torch.log(_node_attr[i]+1), torch.log(outs[i]+1)))
            train_loss += torch.sqrt(loss_f(_node_attr[i], outs[i]))
        
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        if(use_gpu):
            train_losses.append(train_loss.cpu().detach().numpy())
        else:
            train_losses.append(train_loss.detach().numpy())
    return train_losses

In [14]:
train_loop(train_set, train_set.node_attr.shape[0], model.gnn_out_channels, model, loss_f, optimizer, device)

torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.FloatTensor
torch.Floa

[array(2.3493378, dtype=float32),
 array(2.3336272, dtype=float32),
 array(108.26035, dtype=float32),
 array(45.413208, dtype=float32),
 array(44.236977, dtype=float32),
 array(46.69764, dtype=float32),
 array(127.73181, dtype=float32),
 array(42.241356, dtype=float32),
 array(48.34228, dtype=float32),
 array(46.534958, dtype=float32),
 array(42.703968, dtype=float32),
 array(50.730217, dtype=float32),
 array(54.502678, dtype=float32),
 array(46.273712, dtype=float32),
 array(47.68134, dtype=float32),
 array(48.71389, dtype=float32),
 array(47.98499, dtype=float32),
 array(50.367443, dtype=float32),
 array(43.04029, dtype=float32),
 array(44.147854, dtype=float32),
 array(52.13334, dtype=float32),
 array(45.880642, dtype=float32),
 array(62.836544, dtype=float32),
 array(44.025375, dtype=float32),
 array(54.171135, dtype=float32),
 array(45.621925, dtype=float32),
 array(44.78473, dtype=float32),
 array(50.396633, dtype=float32),
 array(49.668083, dtype=float32),
 array(64.06241, dtype

In [15]:
def test_loop(test_set, num_nodes, state_channels, model, loss_fn, device):
    size = len(train_set)

    hidden_encoder_global = create_hidden_global(num_nodes=num_nodes,out_channels=state_channels)
    hidden_decoder_global = create_hidden_global(num_nodes=num_nodes,out_channels=state_channels)
    # cell_global = create_cell_global(num_nodes=len(global_nodes),out_channels=256)
    
    test_losses = []
    with torch.no_grad():
        for i,snapshot in enumerate(test_set,start=1):
            if snapshot.node_attr.shape[0] == 0:
                # print("snapshot_{} has no data...".format(i))
                continue

            node_attr = snapshot.node_attr
            node_index = snapshot.node_index
            edge_attr = snapshot.edge_attr
            edge_index = snapshot.edge_index


            _node_attr = node_attr.to(device)
            _edge_attr = edge_attr.to(device)
            _edge_index = edge_index.to(device)

            pre_h_encoder = torch.tensor(select_hidden_local(hidden_encoder_global, node_index),dtype=torch.float32).to(device)
            pre_h_decoder = torch.tensor(select_hidden_local(hidden_decoder_global, node_index),dtype=torch.float32).to(device)

            # Compute prediction and loss
            outs, h_encoder, h_decoder = model(_node_attr,_edge_index,_edge_attr)

            update_hidden_gobal(hidden_encoder_global, h_encoder, node_index)
            update_hidden_gobal(hidden_decoder_global, h_decoder, node_index)

            test_loss = 0
            for i in range(node_attr.shape[0]):
    #             train_loss += torch.sqrt(loss_f(torch.log(_node_attr[i]+1), torch.log(outs[i]+1)))
                test_loss += torch.sqrt(loss_f(_node_attr[i], outs[i]))
        
        if(use_gpu):
            test_losses.append(test_loss.cpu().detach().numpy())
        else:
            test_losses.append(test_loss.detach().numpy())
    return test_losses
        