In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import mat73

from torch_geometric.data import Data as gData
from torch_geometric.utils import to_networkx, to_undirected
from torch_geometric.nn import MessagePassing

import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric_temporal.nn.recurrent import DCRNN

In [2]:
folder='./data/'
data = mat73.loadmat(folder+'LNP_spk_all.mat')

In [4]:
data.keys()

dict_keys(['spikes'])

In [6]:
data['spikes'].shape

(4800000, 100)

In [11]:
def build_edge_idx(num_nodes):
    # Initialize edge index matrix
    E = torch.zeros((2, num_nodes * (num_nodes - 1)), dtype=torch.long)
    
    # Populate 1st row
    for node in range(num_nodes):
        for neighbor in range(num_nodes - 1):
            E[0, node * (num_nodes - 1) + neighbor] = node

    # Populate 2nd row
    neighbors = []
    for node in range(num_nodes):
        neighbors.append(list(np.arange(node)) + list(np.arange(node+1, num_nodes)))
    E[1, :] = torch.Tensor([item for sublist in neighbors for item in sublist])
    
    return E

In [19]:
class graph_learning(MessagePassing):
    def __init__(self, nodes_num, embedding_dim):
        super(graph_learning, self).__init__(aggr=None) 
        
        self.embedding_dim = embedding_dim
        self.num_nodes = nodes_num
        
        self.conv1 = torch.nn.Conv1d(1, 8, 10, stride=1)  # .to(device)
        self.conv2 = torch.nn.Conv1d(8, 16, 10, stride=1)  # .to(device)
        self.hidden_drop = torch.nn.Dropout(0.2)
        self.fc = torch.nn.Linear(79712, self.embedding_dim)
        self.bn1 = torch.nn.BatchNorm1d(8)
        self.bn2 = torch.nn.BatchNorm1d(16)
        self.bn3 = torch.nn.BatchNorm1d(self.embedding_dim)
        self.fc_out = nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        self.fc_cat = nn.Linear(self.embedding_dim, 2)
#         self.fc_out = torch.nn.Linear(4, 4)
#         self.fc_cat = torch.nn.Linear(4, 2)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight.data)
                m.bias.data.fill_(0.1)

    def forward(self, x, edge_index):
        batch_size = x.shape[0]
        
        print(x.shape)
        x = x.transpose(1, 0).reshape(self.num_nodes, 1, -1)
        print(x.shape)
        x = self.conv1(x)
        print(x.shape)
        x = F.relu(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = x.view(self.num_nodes, -1)
        x = self.fc(x)
        x = F.relu(x)
        x = self.bn3(x)
        
        print(f"before message passing:{x.shape}")
        
        x = self.propagate(edge_index,x=x)
        
        return x

    def message(self, x_i, x_j):
        print(f'x_i: {x_i.shape}')
        print(f'x_j: {x_j.shape}')
        
        x = torch.cat([x_i, x_j], dim=-1)
        print(x.shape)
#         x = torch.relu(self.fc_out(x))
#         print(x.shape)
        
        return x

    def update(self, aggr_out):
        print(f'update: {aggr_out.shape}')
        x = self.fc_cat(aggr_out)
        print(f'after fc_cat: {x.shape}')
        return x

    def aggregate(self, x):
        print(f'aggregate: {x.shape}')
        x = torch.relu(self.fc_out(x))
        print(f'after fc_out{x.shape}')
        return x

In [20]:
nodes = 100
node_feas = torch.rand(1,nodes,50,100)

In [21]:
edge_index = build_edge_idx(num_nodes=nodes)

In [22]:
edge_index

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [ 1,  2,  3,  ..., 96, 97, 98]])

In [23]:
gl = graph_learning(nodes, 50)

In [24]:
adj = gl(node_feas, edge_index)

torch.Size([1, 100, 50, 100])
torch.Size([100, 1, 5000])
torch.Size([100, 8, 4991])
before message passing:torch.Size([100, 50])
x_i: torch.Size([9900, 50])
x_j: torch.Size([9900, 50])
torch.Size([9900, 100])
aggregate: torch.Size([9900, 100])
after fc_outtorch.Size([9900, 50])
update: torch.Size([9900, 50])
after fc_cat: torch.Size([9900, 2])


In [25]:
z = F.gumbel_softmax(adj, tau=1, hard=True)

In [26]:
z_1 = torch.transpose(z, 0,1).detach()

In [27]:
z_1

tensor([[0., 0., 0.,  ..., 0., 1., 1.],
        [1., 1., 1.,  ..., 1., 0., 0.]])

In [28]:
edge_ = []

for ii, rel in enumerate(z_1[0]):
    if bool(rel):
        edge_.append(edge_index[:,ii])
        
new_edge = torch.stack(edge_, dim=-1)

In [29]:
new_edge

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [10, 12, 15,  ..., 95, 97, 98]])

In [30]:
adj = gl(node_feas, new_edge)

torch.Size([1, 100, 50, 100])
torch.Size([100, 1, 5000])
torch.Size([100, 8, 4991])
before message passing:torch.Size([100, 50])
x_i: torch.Size([3167, 50])
x_j: torch.Size([3167, 50])
torch.Size([3167, 100])
aggregate: torch.Size([3167, 100])
after fc_outtorch.Size([3167, 50])
update: torch.Size([3167, 50])
after fc_cat: torch.Size([3167, 2])


In [31]:
z = F.gumbel_softmax(adj, tau=1, hard=True)
z_1 = torch.transpose(z, 0,1).detach()

edge_ = []

for ii, rel in enumerate(z_1[0]):
    if bool(rel):
        edge_.append(edge_index[:,ii])
        
new_edge = torch.stack(edge_, dim=-1)

In [32]:
new_edge

tensor([[ 0,  0,  0,  ..., 31, 31, 31],
        [ 1,  2,  6,  ..., 95, 96, 98]])

In [33]:
new_edge.shape

torch.Size([2, 1607])

In [34]:
class RecurrentGCN(torch.nn.Module):
    def __init__(self, node_features):
        super(RecurrentGCN, self).__init__()
        self.recurrent = DCRNN(node_features, 32, 1)
        self.linear = torch.nn.Linear(32, 1)

    def forward(self, x, edge_index, edge_weight):
        h = self.recurrent(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [35]:
from tqdm import tqdm

model = RecurrentGCN(node_features = 50)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()

RecurrentGCN(
  (recurrent): DCRNN(
    (conv_x_z): DConv()
    (conv_x_r): DConv()
    (conv_x_h): DConv()
  )
  (linear): Linear(in_features=32, out_features=1, bias=True)
)

In [36]:
dcrnn = DCRNN(50, 50, 3)

In [None]:
x_1.shape