In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import mat73

from torch_geometric.data import Data as gData
from torch_geometric.utils import to_networkx, to_undirected
from torch_geometric.nn import MessagePassing

import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric_temporal.nn.recurrent import DCRNN

folder='./data/'
data = mat73.loadmat(folder+'LNP_spk_all.mat')

data.keys()

data['spikes'].shape

In [2]:
def build_edge_idx(num_nodes):
    # Initialize edge index matrix
    E = torch.zeros((2, num_nodes * (num_nodes - 1)), dtype=torch.long)
    
    # Populate 1st row
    for node in range(num_nodes):
        for neighbor in range(num_nodes - 1):
            E[0, node * (num_nodes - 1) + neighbor] = node

    # Populate 2nd row
    neighbors = []
    for node in range(num_nodes):
        neighbors.append(list(np.arange(node)) + list(np.arange(node+1, num_nodes)))
    E[1, :] = torch.Tensor([item for sublist in neighbors for item in sublist])
    
    return E

In [3]:
class graph_learning(MessagePassing):
    def __init__(self, nodes_num, embedding_dim):
        super(graph_learning, self).__init__(aggr=None) 
        
        self.embedding_dim = embedding_dim
        self.num_nodes = nodes_num
        
        self.conv1 = torch.nn.Conv1d(1, 8, 10, stride=1)  # .to(device)
        self.conv2 = torch.nn.Conv1d(8, 16, 10, stride=1)  # .to(device)
        self.hidden_drop = torch.nn.Dropout(0.2)
        self.fc = torch.nn.Linear(79712, self.embedding_dim)
        self.bn1 = torch.nn.BatchNorm1d(8)
        self.bn2 = torch.nn.BatchNorm1d(16)
        self.bn3 = torch.nn.BatchNorm1d(self.embedding_dim)
        self.fc_out = nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        self.fc_cat = nn.Linear(self.embedding_dim, 2)
#         self.fc_out = torch.nn.Linear(4, 4)
#         self.fc_cat = torch.nn.Linear(4, 2)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def forward(self, x, edge_index):
        batch_size = x.shape[0]
        
        print(x.shape)
        x = x.transpose(1, 0).reshape(self.num_nodes, 1, -1)
        print(x.shape)
        x = self.conv1(x)
        print(x.shape)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = x.view(self.num_nodes, -1)
        x = self.fc(x)
        x = F.relu(x)
        x = self.bn3(x)
        
        print(f"before message passing:{x.shape}")
        
        x = self.propagate(edge_index,x=x)
        
        return x

    def message(self, x_i, x_j):
        print(f'x_i: {x_i.shape}')
        print(f'x_j: {x_j.shape}')
        
        x = torch.cat([x_i, x_j], dim=-1)
        print(x.shape)
#         x = torch.relu(self.fc_out(x))
#         print(x.shape)
        
        return x

    def update(self, aggr_out):
        print(f'update: {aggr_out.shape}')
        x = self.fc_cat(aggr_out)
        print(f'after fc_cat: {x.shape}')
        return x

    def aggregate(self, x):
        print(f'aggregate: {x.shape}')
        x = torch.relu(self.fc_out(x))
        print(f'after fc_out{x.shape}')
        return x

In [4]:
nodes = 100
node_feas = torch.rand(1,nodes,50,100)

In [5]:
edge_index = build_edge_idx(num_nodes=nodes)

In [6]:
edge_index

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [ 1,  2,  3,  ..., 96, 97, 98]])

In [7]:
gl = graph_learning(nodes, 50)

In [8]:
adj = gl(node_feas, edge_index)

torch.Size([1, 100, 50, 100])
torch.Size([100, 1, 5000])
torch.Size([100, 8, 4991])
before message passing:torch.Size([100, 50])
x_i: torch.Size([9900, 50])
x_j: torch.Size([9900, 50])
torch.Size([9900, 100])
aggregate: torch.Size([9900, 100])
after fc_outtorch.Size([9900, 50])
update: torch.Size([9900, 50])
after fc_cat: torch.Size([9900, 2])


In [9]:
z = F.gumbel_softmax(adj, tau=1, hard=True)

In [10]:
z_1 = torch.transpose(z, 0,1).detach()

In [11]:
z_1

tensor([[1., 0., 1.,  ..., 1., 1., 1.],
        [0., 1., 0.,  ..., 0., 0., 0.]])

In [12]:
edge_ = []

for ii, rel in enumerate(z_1[0]):
    if bool(rel):
        edge_.append(edge_index[:,ii])
        
new_edge = torch.stack(edge_, dim=-1)

In [13]:
new_edge

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [ 1,  3,  4,  ..., 96, 97, 98]])

In [14]:
adj = gl(node_feas, new_edge)

torch.Size([1, 100, 50, 100])
torch.Size([100, 1, 5000])
torch.Size([100, 8, 4991])
before message passing:torch.Size([100, 50])
x_i: torch.Size([8172, 50])
x_j: torch.Size([8172, 50])
torch.Size([8172, 100])
aggregate: torch.Size([8172, 100])
after fc_outtorch.Size([8172, 50])
update: torch.Size([8172, 50])
after fc_cat: torch.Size([8172, 2])


In [15]:
z = F.gumbel_softmax(adj, tau=1, hard=True)
z_1 = torch.transpose(z, 0,1).detach()

edge_ = []

for ii, rel in enumerate(z_1[0]):
    if bool(rel):
        edge_.append(edge_index[:,ii])
        
new_edge = torch.stack(edge_, dim=-1)

In [16]:
new_edge

tensor([[ 0,  0,  0,  ..., 82, 82, 82],
        [ 1,  2,  3,  ..., 51, 52, 53]])

In [17]:
new_edge.shape

torch.Size([2, 7108])

# GTS

In [18]:
from models.GTS.graph_learning import GlobalGraphLearning
from models.GTS.DCRNN import DCRNN
from torch_geometric_temporal.nn.recurrent import DCRNN as dcrnn

from glob import glob
import yaml
from easydict import EasyDict as edict

In [19]:
config_file = glob('./config/GTS/*.yaml')[0]
config = edict(yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader))

In [20]:
config

{'embedding_dim': 64,
 'nodes_num': 30,
 'kernel_size': 10,
 'stride': 1,
 'conv1_dim': 8,
 'conv2_dim': 16,
 'fc_dim': 15712,
 'node_features': 1,
 'diffusion_k': 3,
 'num_layer': 1,
 'tau': 0.3,
 'output_dim': 1,
 'encoder_length': 70,
 'decoder_length': 30}

In [21]:
node_feas = torch.rand(config.nodes_num, config.node_features, 1000)
edge_index = build_edge_idx(num_nodes=config.nodes_num)

In [22]:
node_feas.shape

torch.Size([30, 1, 1000])

In [23]:
gl = GlobalGraphLearning(config)

In [24]:
adj = gl(node_feas, edge_index)

torch.Size([30, 1, 1000])
torch.Size([30, 1, 1000])
torch.Size([30, 16, 982])
torch.Size([30, 15712])


In [25]:
z_1 = F.gumbel_softmax(adj, tau=0.3, hard=True)
z_1 = torch.transpose(z_1, 0, 1)

In [26]:
edge_ = []

for ii, rel in enumerate(z_1[0]):
    if bool(rel):
        edge_.append(edge_index[:, ii])

adj_matrix = torch.stack(edge_, dim=-1)

In [27]:
encoder_dcrnn = DCRNN(config)
decoder_dcrnn = DCRNN(config)

In [28]:
inputs = node_feas.reshape(-1,30,1,100)
inputs.shape
# batch, node_num, node_feature, total time step

torch.Size([10, 30, 1, 100])

In [32]:
dcgru = dcrnn(64, 64, 3)

In [44]:
inputs[0][:,:,0].shape

torch.Size([30, 1])

In [43]:
dcgru(inputs[0][:,:,0], adj_matrix)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x65 and 128x64)

In [30]:
encoder_hidden_state = encoder_dcrnn(inputs[0].reshape(config.nodes_num,-1), adj_matrix)

torch.Size([30, 100])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x101 and 2x1)

In [45]:
from models.GTS.DCGRU import DCRNN

In [46]:
dcrnn = DCRNN(64, 64, 3)

In [47]:
dcrnn(inputs[0][:,:,0], adj_matrix)

hidden state dim:torch.Size([30, 64])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (30x65 and 128x64)