In [1]:
import sys

sys.path.append('../GraphStructureLearning')

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
import pickle

from torch_geometric.data import Data as gData
from torch_geometric.utils import to_networkx, to_undirected
from torch_geometric.nn import MessagePassing

import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric_temporal.nn.recurrent import DCRNN

# GTS graph learning: exploratory checks on binned spike data

In [33]:
from models.GTS.gts_graph_learning import GTS_Graph_Learning
from models.GTS.gts_forecasting_module import GTS_Forecasting_Module
from models.GTS.DCRNN import DCRNN
from utils.utils import build_fully_connected_edge_idx, build_batch_edge_index

from glob import glob
import yaml
from easydict import EasyDict as edict

In [4]:
# Pick the GTS experiment config. The glob result is sorted so the choice is
# deterministic (glob() returns matches in arbitrary order), and the config is
# wrapped in an EasyDict for attribute access (config.train.lr, ...).
config_files = sorted(glob('./config/GTS/*.yaml'))
if not config_files:
    raise FileNotFoundError('no YAML config found in ./config/GTS/')
config_file = config_files[0]
with open(config_file, 'r') as f:
    config = edict(yaml.load(f, Loader=yaml.FullLoader))

# Seed the RNGs used below (layer initialisation, Dropout, Gumbel noise) from
# the config so a Restart & Run All reproduces the same numbers.
torch.manual_seed(config.seed)
np.random.seed(config.seed)

In [5]:
# Display the loaded experiment configuration.
config

{'exp_name': 1,
 'exp_dir': './exp',
 'use_gpu': False,
 'device': 'cpu',
 'seed': 1010,
 'model_name': 'GTS',
 'graph_learning_module': 'GTS',
 'graph_forecasting_module': 'GTS',
 'initial_edge_index': 'Fully Connected',
 'dataset': {'root': './data/spike_lambda_bin100',
  'name': 'spike_lambda_bin100',
  'total_time_length': 4800,
  'idx_ratio': 0.5,
  'window_size': 20,
  'slide': 5,
  'pred_step': 5,
  'train_valid_test': [4000, 4400, 4800],
  'save': './data/spike_lambda_bin100/'},
 'train': {'optimizer': 'Adam',
  'epoch': 2,
  'loss_function': 'MSELoss',
  'lr': 0.001,
  'momentum': 0.9,
  'wd': 0.0,
  'batch_size': 1,
  'lr_decay': 0.1,
  'lr_decay_steps': [10000]},
 'nodes_num': 100,
 'node_features': 1,
 'hidden_dim': 16,
 'embedding_dim': 16,
 'graph_learning': {'mode': 'weight',
  'to_symmetric': True,
  'kernel_size': [200, 80, 10],
  'stride': [20, 10, 5],
  'conv1_dim': 4,
  'conv2_dim': 4,
  'conv3_dim': 4},
 'forecasting_module': {'diffusion_k': 1,
  'num_layer': 1,
  

In [6]:
# Disabled alternative input: random node features of shape
# (nodes_num, node_features, 1000) instead of the spike data loaded below.
# node_feas = torch.rand(config.nodes_num, config.node_features, 1000)
# edge_index = build_fully_connected_edge_idx(num_nodes=config.nodes_num)

In [7]:
# Binned spike data, presumably one row per node (spike.shape is [100, 4800]
# below). The file is closed deterministically.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('./data/spk_bin_n100.pickle', 'rb') as f:
    spike_raw = pickle.load(f)

# Keep the first config.dataset.total_time_length bins (4800 in the loaded
# config) as a float32 tensor; spike_raw is left untouched so re-runs are safe.
spike = torch.as_tensor(spike_raw[:, :config.dataset.total_time_length], dtype=torch.float32)

In [8]:
# Derive the node count from the data (one row of spike per node, as used by
# Data(x=spike, ...) below) instead of hard-coding 100, so the graph size
# cannot drift from the loaded recordings.
config.nodes_num = spike.shape[0]

In [9]:
# Fully connected directed graph over config.nodes_num nodes without
# self-loops: 100 * 99 = 9900 edges (see the shape below).
edge_index = build_fully_connected_edge_idx(num_nodes=config.nodes_num)

In [10]:
edge_index.shape

torch.Size([2, 9900])

In [11]:
spike.shape

torch.Size([100, 4800])

In [12]:
# Build the GTS graph-learning module from the config.
gl = GTS_Graph_Learning(config)

In [13]:
gl

GTS_Graph_Learning(
  (conv1): Conv1d(1, 4, kernel_size=(200,), stride=(20,))
  (conv2): Conv1d(4, 4, kernel_size=(80,), stride=(10,))
  (conv3): Conv1d(4, 4, kernel_size=(10,), stride=(5,))
  (fc_conv): Conv1d(4, 1, kernel_size=(1,), stride=(1,))
  (fc): Linear(in_features=2, out_features=16, bias=True)
  (hidden_drop): Dropout(p=0.2, inplace=False)
  (bn1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc_cat): Linear(in_features=32, out_features=16, bias=True)
  (fc_out): Linear(in_features=16, out_features=1, bias=True)
)

In [14]:
# Forward pass of the graph learner: one learned value per edge of edge_index
# (shape [num_edges, 1], see below).
# NOTE(review): gl is never switched to eval(), so Dropout/BatchNorm run in
# training mode here — the output is stochastic and BatchNorm running stats
# are updated. Confirm that is intended for this inspection.
adj = gl(spike, edge_index)

In [15]:
adj.shape

torch.Size([9900, 1])

In [16]:
adj = adj.T

In [17]:
adj.shape

torch.Size([1, 9900])

In [None]:
# Sample a hard 0/1 decision for every edge with the Gumbel-softmax trick.
# gumbel_softmax normalises over the LAST dimension, so each edge needs its own
# two-class distribution [keep, drop]; applied to a 1 x num_edges row it would
# instead select exactly one edge out of all of them.
# NOTE(review): assumes gl outputs per-edge probabilities in (0, 1) (the
# printed values suggest so). If it returns raw scores, pair the score with a
# zero column as the logits instead of the log-probabilities built here.
edge_prob = adj.reshape(-1, 1)  # [num_edges, 1] whatever adj's current layout
edge_logits = torch.cat([edge_prob, 1 - edge_prob], dim=1).clamp_min(1e-10).log()
z_1 = F.gumbel_softmax(edge_logits, hard=True)
z_1

In [None]:
z_1.shape

In [None]:
z_1[:,0].shape

In [None]:
edge_index.shape

In [None]:
z_adj = torch.where(z_1[:,0])

In [None]:
z_adj

In [None]:
b= edge_index[0,:][z_adj]

In [None]:
c = edge_index[1,:][z_adj]

In [None]:
a = torch.stack([b,c])

In [None]:
a

In [25]:
from torch_geometric.utils import *

In [None]:
is_undirected(a)

In [None]:
to_undirected(a)

In [18]:
from torch_geometric.data import Data, Batch, Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader

In [20]:
# PyG Data object for the learned graph. PyG expects edge_attr as
# [num_edges, num_edge_features]; reshape(-1, 1) yields that whether adj is
# currently a [1, E] row or an [E, 1] column.
spike_graph = Data(x=spike, edge_index=edge_index, edge_attr=adj.reshape(-1, 1))
assert spike_graph.edge_attr.size(0) == edge_index.size(1)

In [None]:
from torch_scatter import scatter


def to_dense_adj(edge_index, batch=None, edge_attr=None, max_num_nodes=None):
    r"""Converts batched sparse adjacency matrices given by edge indices and
    edge attributes to a single dense batched adjacency matrix.

    Local re-implementation of PyG's ``to_dense_adj`` (it shadows the one
    brought in by ``from torch_geometric.utils import *``), using only plain
    torch ops. Duplicate edges are summed. Empty graphs are supported.

    Args:
        edge_index (LongTensor): The edge indices, shape ``[2, num_edges]``.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Nodes of one example must occupy a
            contiguous index range, in example order. (default: :obj:`None`,
            i.e. a single graph with ``edge_index.max() + 1`` nodes)
        edge_attr (Tensor, optional): Edge weights or multi-dimensional edge
            features, shape ``[num_edges, *]``. (default: :obj:`None`, i.e.
            every edge has weight 1)
        max_num_nodes (int, optional): The size of the output node dimension.
            Edges touching a local node index ``>= max_num_nodes`` are
            dropped. (default: :obj:`None`, i.e. the largest example's size)

    Returns:
        Tensor: Dense adjacency of shape
        ``[B, max_num_nodes, max_num_nodes, *edge_attr.shape[1:]]``.
    """
    if batch is None:
        # Single graph: infer the node count from the largest index (0 if empty).
        num_total = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
        batch = edge_index.new_zeros(num_total)

    batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1
    # Nodes per example, and the global index of each example's first node.
    num_nodes = torch.bincount(batch, minlength=batch_size)
    cum_nodes = torch.cat([batch.new_zeros(1), num_nodes.cumsum(dim=0)])

    # Re-index every edge as (example, local row, local col).
    node_offset = cum_nodes[batch]
    idx0 = batch[edge_index[0]]
    idx1 = edge_index[0] - node_offset[edge_index[0]]
    idx2 = edge_index[1] - node_offset[edge_index[1]]

    if max_num_nodes is None:
        max_num_nodes = int(num_nodes.max())
    elif idx1.numel() > 0 and (idx1.max() >= max_num_nodes
                               or idx2.max() >= max_num_nodes):
        # Drop edges that fall outside the requested output size.
        mask = (idx1 < max_num_nodes) & (idx2 < max_num_nodes)
        idx0 = idx0[mask]
        idx1 = idx1[mask]
        idx2 = idx2[mask]
        edge_attr = None if edge_attr is None else edge_attr[mask]

    if edge_attr is None:
        edge_attr = torch.ones(idx0.numel(), device=edge_index.device)

    size = [batch_size, max_num_nodes, max_num_nodes] + list(edge_attr.size())[1:]
    flattened_size = batch_size * max_num_nodes * max_num_nodes
    adj = torch.zeros([flattened_size] + size[3:], dtype=edge_attr.dtype,
                      device=edge_index.device)

    # Sum every edge into its flattened (example, row, col) slot, so duplicate
    # edges accumulate, then restore the dense shape. index_add keeps autograd
    # flowing back to edge_attr.
    idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2
    adj = adj.index_add(0, idx, edge_attr)

    return adj.view(size)

In [30]:
mat = to_dense_adj(edge_index, edge_attr=adj.T).squeeze()

In [31]:
mat.shape

torch.Size([100, 100])

In [32]:
mat

tensor([[0.0000, 0.6418, 0.8678,  ..., 0.8678, 0.8216, 0.8684],
        [0.7638, 0.0000, 0.9558,  ..., 0.9558, 0.9096, 0.9577],
        [0.4511, 0.4381, 0.0000,  ..., 0.5837, 0.5531, 0.5421],
        ...,
        [0.4511, 0.4381, 0.5837,  ..., 0.0000, 0.5531, 0.5421],
        [0.5148, 0.4944, 0.6722,  ..., 0.6722, 0.0000, 0.6306],
        [0.3747, 0.3681, 0.5120,  ..., 0.5120, 0.4828, 0.0000]],
       grad_fn=<SqueezeBackward0>)

In [66]:
adj.shape

torch.Size([9900, 1])

In [34]:
new_edge = build_batch_edge_index(edge_index, 3)

In [35]:
new_edge

tensor([[  0,   0,   0,  ..., 299, 299, 299],
        [  1,   2,   3,  ..., 296, 297, 298]])

In [36]:
new_edge.shape

torch.Size([2, 29700])

In [37]:
9900*3

29700

In [40]:
adj.shape

torch.Size([9900, 1])

In [48]:
[adj,adj,adj]

[tensor([[0.6418],
         [0.8678],
         [0.8678],
         ...,
         [0.4806],
         [0.5120],
         [0.4828]], grad_fn=<PermuteBackward0>),
 tensor([[0.6418],
         [0.8678],
         [0.8678],
         ...,
         [0.4806],
         [0.5120],
         [0.4828]], grad_fn=<PermuteBackward0>),
 tensor([[0.6418],
         [0.8678],
         [0.8678],
         ...,
         [0.4806],
         [0.5120],
         [0.4828]], grad_fn=<PermuteBackward0>)]

In [54]:
a = torch.stack([adj,adj,adj], dim=0).view(-1,1)

In [55]:
a[:9900]

tensor([[0.6418],
        [0.8678],
        [0.8678],
        ...,
        [0.4806],
        [0.5120],
        [0.4828]], grad_fn=<SliceBackward0>)

In [57]:
a[9900]

tensor([0.6418], grad_fn=<SelectBackward0>)

In [59]:
adj + adj

tensor([[1.2836],
        [1.7356],
        [1.7356],
        ...,
        [0.9613],
        [1.0239],
        [0.9657]], grad_fn=<AddBackward0>)

In [64]:
out_list = []
for _ in range(1):
    out_list.append(adj)
output = torch.stack(out_list).view(-1,1)

In [65]:
output.shape

torch.Size([9900, 1])

In [105]:
temp = add_self_loops(edge_index, adj, 1)

In [106]:
temp[0].shape

torch.Size([2, 10000])

In [107]:
to_dense_adj(temp[0], edge_attr=temp[1]).squeeze()

tensor([[1.0000, 0.6418, 0.8678,  ..., 0.8678, 0.8216, 0.8684],
        [0.7638, 1.0000, 0.9558,  ..., 0.9558, 0.9096, 0.9577],
        [0.4511, 0.4381, 1.0000,  ..., 0.5837, 0.5531, 0.5421],
        ...,
        [0.4511, 0.4381, 0.5837,  ..., 1.0000, 0.5531, 0.5421],
        [0.5148, 0.4944, 0.6722,  ..., 0.6722, 1.0000, 0.6306],
        [0.3747, 0.3681, 0.5120,  ..., 0.5120, 0.4828, 1.0000]],
       grad_fn=<SqueezeBackward0>)