In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import mat73

from torch_geometric.data import Data as gData
from torch_geometric.utils import to_networkx, to_undirected
from torch_geometric.nn import MessagePassing

import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric_temporal.nn.recurrent import DCRNN

# GTS

In [2]:
from models.GTS.graph_learning import GlobalGraphLearning
from models.GTS.DCRNN import DCRNN
# from torch_geometric_temporal.nn.recurrent import DCRNN as dcrnn
from utils.utils import build_edge_idx

from glob import glob
import yaml
from easydict import EasyDict as edict

In [3]:
# Load the GTS experiment config into an EasyDict (attribute-style access).
# Sorting the glob result makes the choice deterministic across filesystems, and
# a missing config fails with a clear message instead of a bare IndexError.
config_paths = sorted(glob('./config/GTS/*.yaml'))
if not config_paths:
    raise FileNotFoundError("no YAML config found under ./config/GTS/")
config_file = config_paths[0]
# `with` closes the file handle; safe_load is sufficient for a plain config and
# cannot construct arbitrary Python objects.
with open(config_file, 'r') as config_fh:
    config = edict(yaml.safe_load(config_fh))

In [4]:
# Synthetic smoke-test data: uniform-random features of shape
# (nodes_num, node_features, N_TIMESTEPS) plus the candidate edge list.
# Seeding here makes this data, the randomly initialised modules built below and
# the Gumbel-softmax edge sampling reproducible on Restart & Run All.
SEED = 0
N_TIMESTEPS = 1000  # length of the synthetic series per node
torch.manual_seed(SEED)
node_feas = torch.rand(config.nodes_num, config.node_features, N_TIMESTEPS)
edge_index = build_edge_idx(num_nodes=config.nodes_num)

In [5]:
# Candidate edge list from build_edge_idx: (2, num_candidate_edges).
edge_index.size()

torch.Size([2, 870])

In [6]:
# Synthetic node features: (nodes_num, node_features, timesteps).
node_feas.size()

torch.Size([30, 1, 1000])

In [7]:
# Graph-structure learner (models.GTS.graph_learning), configured from the loaded YAML config.
gl = GlobalGraphLearning(config)

In [8]:
# Run the graph learner on the node series and candidate edges; the result is
# presumably per-candidate-edge logits (it is fed to gumbel_softmax next) — TODO confirm shape.
# The shape lines printed under this cell come from debug prints inside gl's forward,
# not from this cell.
adj = gl(node_feas, edge_index)

torch.Size([30, 1, 1000])
torch.Size([30, 1, 1000])
torch.Size([30, 16, 982])
torch.Size([30, 15712])


In [9]:
# Draw a hard (one-hot along the last dim, straight-through gradient) Gumbel-softmax
# sample from the learned logits. The temperature comes from the config instead of
# a duplicated 0.3 literal, so changing `tau` in the YAML actually takes effect.
z_1 = F.gumbel_softmax(adj, tau=config.tau, hard=True)
# Transpose so that z_1[0] is a row indexed by candidate-edge column of edge_index
# (assumes adj is (num_candidate_edges, num_classes) — TODO confirm).
z_1 = torch.transpose(z_1, 0, 1)

In [10]:
# Keep the candidate edges whose sampled indicator in z_1[0] is non-zero.
# Boolean-mask column selection preserves edge order (same result as the former
# per-column loop + torch.stack) and yields an empty (2, 0) tensor instead of
# raising if no edge happens to be sampled.
# NOTE: despite its name, adj_matrix is an edge list of shape (2, num_kept_edges),
# not a dense adjacency matrix.
edge_mask = z_1[0].bool()
adj_matrix = edge_index[:, edge_mask]

In [11]:
# Show the configuration as loaded from YAML (the following cells then override
# num_layer, embedding_dim and in_channels for this experiment).
config

{'embedding_dim': 64,
 'nodes_num': 30,
 'kernel_size': 10,
 'stride': 1,
 'conv1_dim': 8,
 'conv2_dim': 16,
 'fc_dim': 15712,
 'node_features': 1,
 'diffusion_k': 3,
 'num_layer': 1,
 'tau': 0.3,
 'output_dim': 1,
 'encoder_length': 70,
 'decoder_length': 30}

In [12]:
# Experiment override: stack two DCRNN layers instead of the YAML value.
config['num_layer'] = 2

In [13]:
# Experiment overrides for the DCRNN stack:
# - embedding_dim: hidden size of each recurrent layer (overrides the YAML value).
# - in_channels: per-node input width of one time step. It equals node_features
#   (each step fed to the encoder is (nodes_num, node_features)), so derive it
#   rather than hard-coding 1 and risking drift when the config changes.
config.embedding_dim = 13
config.in_channels = config.node_features

In [14]:
# Separate (independently initialised) encoder and decoder DCRNN stacks built
# from the same config.
# NOTE: `DCRNN` here resolves to models.GTS.DCRNN (imported in the second cell),
# which shadows the torch_geometric_temporal DCRNN imported in the first cell.
encoder_dcrnn = DCRNN(config)
decoder_dcrnn = DCRNN(config)

In [15]:
# Cut the long per-node series into fixed-length windows, one window per sample:
#   node_feas (nodes, feats, T) -> inputs (batch, node_num, node_feature, window_len)
# A plain reshape(-1, 30, 1, 100) would reinterpret the node-major memory and mix
# nodes with time (the "nodes" of batch 0 would be time slices of the first few
# real nodes). Instead, split the time axis and move the window index to the front.
# Assumes one sample spans encoder_length + decoder_length steps (70 + 30 = 100
# in the shown config) — TODO confirm this is the intended window.
window_len = config.encoder_length + config.decoder_length
n_nodes, n_feats, n_steps = node_feas.shape
n_windows = n_steps // window_len  # trailing steps that do not fill a window are dropped
inputs = (
    node_feas[..., : n_windows * window_len]
    .reshape(n_nodes, n_feats, n_windows, window_len)
    .permute(2, 0, 1, 3)
    .contiguous()
)
inputs.shape

torch.Size([10, 30, 1, 100])

In [23]:
# One time step across the whole batch: (batch, node_num, node_feature).
inputs[..., 0].shape

torch.Size([10, 30, 1])

In [27]:
# Put both recurrent stacks in inference mode for the forward-pass checks below.
# Switching only the encoder would leave the decoder in training mode, so any
# mode-dependent layers would behave differently between the two.
decoder_dcrnn.eval()
encoder_dcrnn.eval()

DCRNN(
  (recurrent): ModuleList(
    (0): DCRNN(
      (conv_x_z): DConv()
      (conv_x_r): DConv()
      (conv_x_h): DConv()
    )
    (1): DCRNN(
      (conv_x_z): DConv()
      (conv_x_r): DConv()
      (conv_x_h): DConv()
    )
  )
)

In [26]:
# Inspect the encoder's train/eval mode flag.
# NOTE(review): the saved output (True) has execution count 26, i.e. it was produced
# *before* encoder_dcrnn.eval() (count 27) ran; on Restart & Run All this should
# show False. Re-run to refresh the stale output.
encoder_dcrnn.training

True

In [29]:
# Encoder step t = 0 for the first sample; no hidden_state is passed in.
# The lines printed below come from debug prints inside the DCRNN forward.
x_t0 = inputs[0, ..., 0]
encoder_hidden_state = encoder_dcrnn(x_t0, adj_matrix)

0
Initial Hidden State Dim: torch.Size([30, 13])
Z shape: torch.Size([30, 14])
1
Initial Hidden State Dim: torch.Size([30, 13])
Z shape: torch.Size([30, 26])


In [30]:
# Encoder step t = 1 for the first sample, continuing from the t = 0 hidden state.
x_t1 = inputs[0, ..., 1]
encoder_hidden_state_2 = encoder_dcrnn(
    x_t1,
    adj_matrix,
    hidden_state=encoder_hidden_state,
)

0
Initial Hidden State Dim: torch.Size([30, 13])
Z shape: torch.Size([30, 14])
1
Initial Hidden State Dim: torch.Size([30, 13])
Z shape: torch.Size([30, 26])


In [31]:
# First decoder step, seeded with the encoder's hidden state after t = 0 and t = 1.
# NOTE(review): this feeds time step 3, skipping t = 2, while the encoder only
# consumed t = 0 and t = 1. It looks like an off-by-one, or the decoder input should
# be a GO/zero symbol or the last observed step instead. Feeding a future step
# would also leak target information. Confirm the intended decoder input before training.
decoder_hidden_state = decoder_dcrnn(inputs[0,:,:,3], adj_matrix, hidden_state=encoder_hidden_state_2)

0
Initial Hidden State Dim: torch.Size([30, 13])
Z shape: torch.Size([30, 14])
1
Initial Hidden State Dim: torch.Size([30, 13])
Z shape: torch.Size([30, 26])


In [32]:
# Last entry of the decoder hidden state (presumably the top DCRNN layer):
# (nodes_num, embedding_dim).
decoder_hidden_state[-1].size()

torch.Size([30, 13])

In [34]:
# Linear read-out head: maps each node's embedding_dim hidden vector to output_dim values.
prediction_layer = nn.Linear(
    in_features=config.embedding_dim,
    out_features=config.output_dim,
)

In [35]:
# Inspect the head: in_features should equal config.embedding_dim and
# out_features config.output_dim.
prediction_layer

Linear(in_features=13, out_features=1, bias=True)

In [37]:
# Apply the head to the last decoder hidden-state entry: one prediction row per node,
# shape (nodes_num, output_dim).
node_predictions = prediction_layer(decoder_hidden_state[-1])
node_predictions.shape

torch.Size([30, 1])