## Step 1: Imports

In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.utils
import torch.utils.data
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt 
from scipy.ndimage import rotate
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal
import scipy.sparse as sp
from scipy.linalg import block_diag
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import torch.nn.functional as F
import copy
import time
from torch_scatter import scatter_mean, scatter_max, scatter_add
from torch_geometric.utils import remove_self_loops,add_self_loops
from torch_geometric.datasets import Planetoid
import networkx as nx
import scipy.io as sio
import torch_scatter
import inspect
import pickle

## Step 2: Utility functions

In [2]:

def uniform(size, tensor):
    stdv = 1.0 / math.sqrt(size)
    if tensor is not None:
        tensor.data.uniform_(-stdv, stdv)


def glorot(tensor):
    stdv = math.sqrt(6.0 / (tensor.size(0) + tensor.size(1)))
    if tensor is not None:
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    if tensor is not None:
        tensor.data.fill_(0)


def ones(tensor):
    if tensor is not None:
        tensor.data.fill_(1)


def reset(nn):
    def _reset(item):
        if hasattr(item, 'reset_parameters'):
            item.reset_parameters()

    if nn is not None:
        if hasattr(nn, 'children') and len(list(nn.children())) > 0:
            for item in nn.children():
                _reset(item)
        else:
            _reset(nn)


def scatter_(name, src, index, dim_size=None):
    r"""Aggregates all values from the :attr:`src` tensor at the indices
    specified in the :attr:`index` tensor along the first dimension.
    If multiple indices reference the same location, their contributions
    are aggregated according to :attr:`name` (either :obj:`"add"`,
    :obj:`"mean"` or :obj:`"max"`).
    Args:
        name (string): The aggregation to use (:obj:`"add"`, :obj:`"mean"`,
            :obj:`"max"`).
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim_size (int, optional): Automatically create output tensor with size
            :attr:`dim_size` in the first dimension. If set to :attr:`None`, a
            minimal sized output tensor is returned. (default: :obj:`None`)
    :rtype: :class:`Tensor`
    """

    assert name in ['add', 'mean', 'max']

    op = getattr(torch_scatter, 'scatter_{}'.format(name))
    fill_value = -1e38 if name is 'max' else 0   
    out = op(src, index, 0, None, dim_size)
    if isinstance(out, tuple):
        out = out[0]

    if name is 'max':
        out[out == fill_value] = 0

    return out


def sparse_to_tuple(sparse_mx):
    if not sp.isspmatrix_coo(sparse_mx):
        sparse_mx = sparse_mx.tocoo()
    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
    values = sparse_mx.data
    shape = sparse_mx.shape
    return coords, values, shape

  fill_value = -1e38 if name is 'max' else 0
  if name is 'max':


In [3]:
# Example usage
size = 10
tensor = torch.empty(size)
uniform(size, tensor)
print(tensor)

tensor([-0.1213,  0.0715,  0.0323, -0.1599,  0.1078,  0.0466, -0.2678, -0.0564,
        -0.2276,  0.0480])


In [4]:
# Example usage
        
# Define a tensor with size (3, 4)
tensor = torch.empty(3, 4)

# Initialize the tensor using the glorot function
glorot(tensor)

# Print the initialized tensor
print(tensor)

tensor([[-0.6866,  0.0657, -0.4963, -0.0306],
        [ 0.1382, -0.2459, -0.2515, -0.1995],
        [ 0.3815,  0.7853, -0.0521, -0.8321]])


In [5]:
# Example usage
# Define a simple neural network
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# Create an instance of the neural network
net = MyNet()

# Print the initial parameters of the neural network
print("Initial parameters:")
for name, param in net.named_parameters():
    print(name, param)


reset(net)

# Print the reset parameters of the neural network
print("\nReset parameters:")
for name, param in net.named_parameters():
    print(name, param)

Initial parameters:
fc1.weight Parameter containing:
tensor([[-0.2286,  0.1802,  0.2226,  0.3117, -0.2274,  0.0113, -0.0683, -0.1334,
         -0.2565, -0.0300],
        [-0.2392,  0.2775, -0.2181, -0.1134,  0.2712,  0.0658,  0.1822, -0.2199,
         -0.0749,  0.0434],
        [ 0.1919,  0.1999, -0.1146, -0.2826,  0.0610, -0.2276, -0.2894,  0.2838,
          0.1042,  0.0110],
        [-0.2643,  0.1969,  0.1768, -0.1445,  0.1862,  0.2273, -0.2724, -0.1970,
          0.1848, -0.2721],
        [ 0.2959,  0.0982, -0.1865, -0.0401,  0.2100,  0.2835,  0.3000, -0.1574,
          0.1678,  0.1703]], requires_grad=True)
fc1.bias Parameter containing:
tensor([-0.2768,  0.2358, -0.1112, -0.0792, -0.1854], requires_grad=True)
fc2.weight Parameter containing:
tensor([[ 0.3381,  0.2889, -0.1800, -0.1641, -0.3632],
        [-0.2566,  0.3771,  0.2301, -0.0523, -0.1916]], requires_grad=True)
fc2.bias Parameter containing:
tensor([ 0.1565, -0.2497], requires_grad=True)

Reset parameters:
fc1.weight Para

In [6]:
# Example usage

 # Define the source tensor
src = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Define the index tensor
index = torch.tensor([0, 1, 0])

# Call the scatter_() function with the "add" aggregation
out = scatter_('add', src, index)

# Print the output tensor
print(out)

tensor([[ 8, 10, 12],
        [ 4,  5,  6]])


In [7]:

class MessagePassing(torch.nn.Module):
    r"""Base class for creating message passing layers
    .. math::
        \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i,
        \square_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}}
        \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{i,j}\right) \right),
    where :math:`\square` denotes a differentiable, permutation invariant
    function, *e.g.*, sum, mean or max, and :math:`\gamma_{\mathbf{\Theta}}`
    and :math:`\phi_{\mathbf{\Theta}}` denote differentiable functions such as
    MLPs.
    See `here <https://rusty1s.github.io/pytorch_geometric/build/html/notes/
    create_gnn.html>`__ for the accompanying tutorial.
    """

    def __init__(self, aggr='add'):
        super(MessagePassing, self).__init__()


        message_signature = inspect.signature(self.message)
        # Extract the names of parameters excluding the first one (self for methods)
        self.message_args = [param.name for param in message_signature.parameters.values()][0:]
        update_signature = inspect.signature(self.update)
        # Extract the names of parameters excluding the first one (self for methods)
        self.update_args = [param.name for param in update_signature.parameters.values()][1:]

    def propagate(self, aggr, edge_index, **kwargs):
        r"""The initial call to start propagating messages.
        Takes in an aggregation scheme (:obj:`"add"`, :obj:`"mean"` or
        :obj:`"max"`), the edge indices, and all additional data which is
        needed to construct messages and to update node embeddings."""

        assert aggr in ['add', 'mean', 'max']
        kwargs['edge_index'] = edge_index

        size = None
        message_args = []
        for arg in self.message_args:
            if arg[-2:] == '_i':
                tmp = kwargs[arg[:-2]]
                size = tmp.size(0)
                message_args.append(tmp[edge_index[0]])
            elif arg[-2:] == '_j':
                tmp = kwargs[arg[:-2]]
                size = tmp.size(0)
                message_args.append(tmp[edge_index[1]])
            else:
                message_args.append(kwargs[arg])

        update_args = [kwargs[arg] for arg in self.update_args]

        # Ensure there is at least one argument for the message function
        if not message_args:
            message_args.append(kwargs['x'])  # Use the node features as a default

        out = self.message(*message_args)
        out = scatter_(aggr, out, edge_index[0], dim_size=size)
        out = self.update(out, *update_args)
        return out

    def message(self, x_j):  # pragma: no cover
        r"""Constructs messages in analogy to :math:`\phi_{\mathbf{\Theta}}`
        for each edge in :math:`(i,j) \in \mathcal{E}`.
        Can take any argument which was initially passed to :meth:`propagate`.
        In addition, features can be lifted to the source node :math:`i` and
        target node :math:`j` by appending :obj:`_i` or :obj:`_j` to the
        variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`."""

        return x_j


    def update(self, aggr_out):  # pragma: no cover
        r"""Updates node embeddings in analogy to
        :math:`\gamma_{\mathbf{\Theta}}` for each node
        :math:`i \in \mathcal{V}`.
        Takes in the output of aggregation as first argument and any argument
        which was initially passed to :meth:`propagate`."""

        return aggr_out
    

# Example usage

# Create an instance of the MessagePassing class
mp = MessagePassing()

# Define the graph with three nodes and two directed edges
edge_index = torch.tensor([[0, 1, 1, 2, 2], [1, 0, 2, 1, 2]], dtype=torch.long)

# Create initial node features
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float)



# Propagate messages and update node embeddings
output = mp.propagate('add', edge_index, x=x)

# Print the results
print("Original Node Features:")
print(x)
print("\nEdge Index:")
print(edge_index)
print("\nOutput after Propagation and Update:")
print(output)


Original Node Features:
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])

Edge Index:
tensor([[0, 1, 1, 2, 2],
        [1, 0, 2, 1, 2]])

Output after Propagation and Update:
tensor([[ 3.,  4.],
        [ 6.,  8.],
        [ 8., 10.]])


In [8]:
def tuple_to_array(lot):
    out = np.array(list(lot[0]))
    for i in range(1, len(lot)):
        out = np.vstack((out, np.array(list(lot[i]))))
    
    return out

In [9]:
# Example usage
lot = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
result = tuple_to_array(lot)
print(result)

[[1 2 3]
 [4 5 6]
 [7 8 9]]


## Step 3: Edge Extraction

In [10]:
def extract_edges(adjs_list):
    edges_list = []
    for i in range(0, len(adjs_list)):
        
        adj = adjs_list[i]
        adj.eliminate_zeros()        
        adj_triu = sp.triu(adj)
        adj_tuple = sparse_to_tuple(adj_triu)
        edges = adj_tuple[0]
        edges_list.append(edges)
    return edges_list


In [11]:
# Example usage
# Create an adjacency matrix

temp_matrix = np.random.randint(2, size=(30, 30))
temp_adj = np.triu(temp_matrix) + np.triu(temp_matrix, 1).T
adj = [sp.csr_matrix(temp_adj)]

edge_lists = extract_edges(adj)

print("edge lists = ", edge_lists)

edge lists =  [array([[ 0,  3],
       [ 0,  4],
       [ 0,  5],
       [ 0,  7],
       [ 0,  8],
       [ 0, 11],
       [ 0, 12],
       [ 0, 15],
       [ 0, 20],
       [ 0, 23],
       [ 0, 25],
       [ 0, 26],
       [ 1,  1],
       [ 1,  2],
       [ 1,  3],
       [ 1,  8],
       [ 1,  9],
       [ 1, 10],
       [ 1, 13],
       [ 1, 14],
       [ 1, 16],
       [ 1, 17],
       [ 1, 19],
       [ 1, 20],
       [ 1, 21],
       [ 1, 22],
       [ 1, 23],
       [ 1, 24],
       [ 1, 28],
       [ 1, 29],
       [ 2,  4],
       [ 2,  6],
       [ 2,  7],
       [ 2,  8],
       [ 2, 11],
       [ 2, 13],
       [ 2, 15],
       [ 2, 18],
       [ 2, 19],
       [ 2, 20],
       [ 2, 21],
       [ 2, 24],
       [ 2, 25],
       [ 2, 26],
       [ 2, 29],
       [ 3,  4],
       [ 3,  5],
       [ 3,  7],
       [ 3, 10],
       [ 3, 15],
       [ 3, 17],
       [ 3, 18],
       [ 3, 20],
       [ 3, 22],
       [ 3, 23],
       [ 3, 25],
       [ 3, 27],
       [ 3, 28],

## Step 4: GCN, Temporal Attention, GRU Layers

In [12]:
class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels, act=F.relu, improved=False, bias=False):
        super(GCNConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.act = act

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)

    def forward(self, x, edge_index, edge_weight=None):
        if edge_weight is None:
            edge_weight = torch.ones(
                (edge_index.size(1), ), dtype=x.dtype, device=x.device)
        edge_weight = edge_weight.view(-1)
        assert edge_weight.size(0) == edge_index.size(1)

        #if your original networks do not have self-loops, uncomment the following:
        
        # edge_index = add_self_loops(edge_index, num_nodes=x.size(0))
        # loop_weight = torch.full(
        #     (x.size(0), ),
        #     1 if not self.improved else 2,
        #     dtype=x.dtype,
        #     device=x.device)
        # edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=x.size(0))
        # print('deg= ',deg)
        deg_inv = deg.pow(-0.5)
        deg_inv[deg_inv == float('inf')] = 0
        # print("deg_inv = ",deg_inv)
        norm = deg_inv[row] * edge_weight * deg_inv[col]
        # print('norm = ',norm)
        # print('weigh = ', self.weight)
        x = torch.matmul(x, self.weight)
        # print('X = ',x)
        out = self.propagate('add', edge_index, x=x, norm=norm)
        return self.act(out)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [13]:
edge_index = torch.tensor([[0, 1, 1, 2, 2], [1, 0, 2, 1, 2]], dtype=torch.long)

# Create initial node features
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=torch.float)

In [14]:
gcn = GCNConv(2, 2, act=lambda x:x,  improved=False, bias=False)
gcn.forward( x, edge_index, edge_weight=None)

tensor([[-1.9737,  0.7632],
        [-2.9930,  1.2194],
        [-3.6995,  1.2808]], grad_fn=<ScatterAddBackward0>)

In [15]:
class AttentionLayer(nn.Module):
    """Implements an Attention Layer"""

    def __init__(self, cuda, nhid):
        super(AttentionLayer, self).__init__()
        self.nhid = nhid


        # Linear transformations for K, Q, V from the same source
        self.key = nn.Linear(nhid, nhid)
        self.query = nn.Linear(nhid, nhid)
        self.value = nn.Linear(nhid, nhid)
        self.softmax = nn.Softmax()
        self.cuda = cuda
        self.attention_weights = []

    def forward(self, all_h, attention_width=5,mask=None):

        temp = [tensor[-1, :, :] for tensor in all_h]
        temp = (torch.stack(temp, dim=0))
        all_attentions_weights = []
        if temp.size(0) >= attention_width:
            lb = temp.size(0) - attention_width
            if(lb<0):
                lb = 0

            query = self.query(temp[-1,-1,:].view(1,-1))
            # print('shape query =', query.size())
            keys = self.key(temp[lb:,-1 ,:])
            # print('shape keys =', keys.size())
            scores = torch.matmul(query, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.nhid, dtype=torch.float32))
                
            # Apply mask (if provided)
            if mask is not None:
                scores = scores.masked_fill(mask == 0, -1e9)

            # Apply softmax
            attention_weights = F.softmax(scores, dim=1)
            all_attentions_weights.append(attention_weights)

            # print('att weights size =', attention_weights.size())
            
            values = self.value( (torch.stack(all_h[lb:], dim=0)))
            # print('shape values =', values.size())
            
            expanded_weights = attention_weights.view(1, attention_weights.size(1), 1, 1, 1)
            
            weighted_values = values * expanded_weights
            # print('wieghted values size = ', weighted_values.size())
            
            # Sum along the first dimension to get the final weighted sum tensor 
            weighted_sum = torch.sum(weighted_values, dim=1).squeeze(dim=0)
            # print('final tensor size = ', weighted_sum.size())
            all_h[-1] = weighted_sum

       
        return all_h,all_attentions_weights
    
        



In [16]:
# Example usage
temp_tensor = [torch.rand(1,3,4) for _ in range(6)]
att_layer = AttentionLayer(cuda = False, nhid=4)
att_layer.forward(temp_tensor, 3)

([tensor([[[0.6762, 0.8613, 0.2314, 0.3677],
           [0.6751, 0.8532, 0.3005, 0.7811],
           [0.1339, 0.9738, 0.8608, 0.2614]]]),
  tensor([[[0.2963, 0.6159, 0.4210, 0.2477],
           [0.3208, 0.7002, 0.0843, 0.4030],
           [0.4868, 0.1876, 0.1901, 0.3922]]]),
  tensor([[[0.5513, 0.2155, 0.3483, 0.6454],
           [0.4497, 0.1505, 0.5163, 0.5197],
           [0.3350, 0.6046, 0.1174, 0.1685]]]),
  tensor([[[0.2771, 0.6292, 0.9451, 0.1214],
           [0.6667, 0.4044, 0.2850, 0.7135],
           [0.9281, 0.8554, 0.2401, 0.6911]]]),
  tensor([[[0.4869, 0.7268, 0.0512, 0.6582],
           [0.9918, 0.0026, 0.5540, 0.5818],
           [0.4263, 0.6786, 0.7138, 0.9145]]]),
  tensor([[[ 0.3245,  0.4134, -0.1732,  0.1117],
           [ 0.4323,  0.3662, -0.0302,  0.1905],
           [ 0.5268,  0.3781,  0.0303,  0.1209]]], grad_fn=<SqueezeBackward1>)],
 [tensor([[0.3603, 0.3439, 0.2958]], grad_fn=<SoftmaxBackward0>)])

In [17]:
class graph_gru_attention(nn.Module):
    def __init__(self, input_size, hidden_size, n_layer, bias=True):
        super(graph_gru_attention, self).__init__()

        self.hidden_size = hidden_size
        self.n_layer = n_layer
        cuda=False
        attention_width=5
        # gru weights
        self.weight_xz = []
        self.weight_hz = []
        self.weight_xr = []
        self.weight_hr = []
        self.weight_xh = []
        self.weight_hh = []
        
        self.attention_width = attention_width
        self.AttentionLayer = AttentionLayer(cuda,hidden_size)
        
        for i in range(self.n_layer):
            if i==0:
                self.weight_xz.append(GCNConv(input_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_hz.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_xr.append(GCNConv(input_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_hr.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_xh.append(GCNConv(input_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_hh.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
            else:
                self.weight_xz.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_hz.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_xr.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_hr.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_xh.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
                self.weight_hh.append(GCNConv(hidden_size, hidden_size, act=lambda x:x, bias=bias))
    
    def forward(self, inp, edgidx, all_h):
        h_hat = torch.zeros(all_h[-1].size())
        for i in range(self.n_layer):
            if i==0:
                s_g = torch.sigmoid(self.weight_xz[i](inp, edgidx) + self.weight_hz[i](all_h[-1][i], edgidx))
                r_g = torch.sigmoid(self.weight_xr[i](inp, edgidx) + self.weight_hr[i](all_h[-1][i], edgidx))
                h_tilde_g = torch.tanh(self.weight_xh[i](inp, edgidx) + self.weight_hh[i](r_g * all_h[-1][i], edgidx))
                h_hat[i] = s_g * all_h[-1][i] + (1 - s_g) * h_tilde_g
                
            else:

                s_g = torch.sigmoid(self.weight_xz[i](h_hat[i-1], edgidx) + self.weight_hz[i](all_h[-1][i], edgidx))
                r_g = torch.sigmoid(self.weight_xr[i](h_hat[i-1], edgidx) + self.weight_hr[i](all_h[-1][i], edgidx))
                h_tilde_g = torch.tanh(self.weight_xh[i](h_hat[i-1], edgidx) + self.weight_hh[i](r_g * all_h[-1][i], edgidx))
                h_hat[i] = s_g * all_h[-1][i] + (1 - s_g) * h_tilde_g
                
        all_h.append(h_hat)
        
        all_h, attention_weights = self.AttentionLayer.forward(all_h, self.attention_width)
        
        out = all_h[-1]
        return all_h, out, attention_weights


In [18]:
class InnerProductDecoder(nn.Module):
    def __init__(self, act=torch.sigmoid, dropout=0.):
        super(InnerProductDecoder, self).__init__()
        
        self.act = act
        self.dropout = dropout
    
    def forward(self, inp):
        inp = F.dropout(inp, self.dropout, training=self.training)
        x = torch.transpose(inp, dim0=0, dim1=1)
        x = torch.mm(inp, x)
        return self.act(x)

## Step 5: Temporal Attention-enhanced Variational Graph Recurrent Neural Network (T-AVRNN) Model

In [19]:
class T_AVRNN(nn.Module):
    def __init__(self, x_dim, h_dim, z_dim, n_layers, eps, bias=False):
        super(T_AVRNN, self).__init__()
        
        self.x_dim = x_dim
        self.eps = eps
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.n_layers = n_layers
        
        self.phi_x = nn.Sequential(nn.Linear(x_dim, h_dim), nn.ReLU())
        self.phi_z = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU())
        
        self.enc = GCNConv(h_dim + h_dim, h_dim)            
        self.enc_mean = GCNConv(h_dim, z_dim, act=lambda x:x)
        self.enc_std = GCNConv(h_dim, z_dim, act=F.softplus)
        
        self.prior = nn.Sequential(nn.Linear(h_dim, h_dim), nn.ReLU())
        self.prior_mean = nn.Sequential(nn.Linear(h_dim, z_dim))
        self.prior_std = nn.Sequential(nn.Linear(h_dim, z_dim), nn.Softplus())
        
        self.rnn = graph_gru_attention(h_dim + h_dim, h_dim, n_layers, bias)
            
  
    
    def forward(self, x, edge_idx_list, adj_orig_dense_list, hidden_in=None):
        assert len(adj_orig_dense_list) == len(edge_idx_list)
        
        kld_loss = 0
        nll_loss = 0
        all_enc_mean, all_enc_std = [], []
        all_prior_mean, all_prior_std = [], []
        all_dec_t, all_z_t = [], []
        all_h = []
        all_attention_weights = []
        
        if hidden_in is None:
            h = Variable(torch.zeros(self.n_layers, x.size(1), self.h_dim))
        else:
            h = Variable(hidden_in)

        all_h.append(h)
        for t in range(x.size(0)):
            phi_x_t = self.phi_x(x[t])
            
            #encoder
            # print('phi(x) size = ', phi_x_t.size())
            # print('h[-1] size =', h[-1].size())
            enc_t = self.enc(torch.cat([phi_x_t, h[-1]], 1), edge_idx_list[t])
            # print('enc_t size = ',enc_t.size())
            enc_mean_t = self.enc_mean(enc_t, edge_idx_list[t])
            # print('enc_mean_t size = ',enc_mean_t.size())
            enc_std_t = self.enc_std(enc_t, edge_idx_list[t])
            # print('enc_std_t size = ',enc_std_t.size())
            
            #prior
            prior_t = self.prior(h[-1])
            # print('prior_t size = ',prior_t.size())
            prior_mean_t = self.prior_mean(prior_t)
            # print('prior_mean_t size = ',prior_mean_t.size())
            prior_std_t = self.prior_std(prior_t)
            # print('prior_std_t size = ',prior_std_t.size())
            
            #sampling and reparameterization
            z_t = self._reparameterized_sample(enc_mean_t, enc_std_t)
            # print('z_t size = ,',z_t.size())
            phi_z_t = self.phi_z(z_t)
            # print('phi_z_t size = ',phi_z_t.size())
            
            #decoder
            dec_t = self.dec(z_t)
            
            #recurrence
            all_h, h , attention_weights_t = self.rnn(torch.cat([phi_x_t, phi_z_t], 1), edge_idx_list[t], all_h)
            
            nnodes = adj_orig_dense_list[t].size()[0]
            enc_mean_t_sl = enc_mean_t[0:nnodes, :]
            enc_std_t_sl = enc_std_t[0:nnodes, :]
            prior_mean_t_sl = prior_mean_t[0:nnodes, :]
            prior_std_t_sl = prior_std_t[0:nnodes, :]
            dec_t_sl = dec_t[0:nnodes, 0:nnodes]
            
            #computing losses
            kld_loss += self._kld_gauss(enc_mean_t_sl, enc_std_t_sl, prior_mean_t_sl, prior_std_t_sl)
            nll_loss += self._nll_bernoulli(dec_t_sl, adj_orig_dense_list[t])
            
            # all_enc_std.append(enc_std_t_sl)
            # all_enc_mean.append(enc_mean_t_sl)
            # all_prior_mean.append(prior_mean_t_sl)
            # all_prior_std.append(prior_std_t_sl)
            # all_dec_t.append(dec_t_sl)
            all_z_t.append(z_t)
            all_attention_weights.append(attention_weights_t)
        return kld_loss, nll_loss, all_z_t , all_attention_weights
    
    def dec(self, z):
        outputs = InnerProductDecoder(act=lambda x:x)(z)
        return outputs
    
    def reset_parameters(self, stdv=1e-1):
        for weight in self.parameters():
            weight.data.normal_(0, stdv)
     
    def _init_weights(self, stdv):
        pass
    
    def _reparameterized_sample(self, mean, std):
        eps1 = torch.FloatTensor(std.size()).normal_()
        eps1 = Variable(eps1)
        return eps1.mul(std).add_(mean)
    
    def _kld_gauss(self, mean_1, std_1, mean_2, std_2):
        num_nodes = mean_1.size()[0]
        kld_element =  (2 * torch.log(std_2 + self.eps) - 2 * torch.log(std_1 + self.eps) +
                        (torch.pow(std_1 + self.eps ,2) + torch.pow(mean_1 - mean_2, 2)) / 
                        torch.pow(std_2 + self.eps ,2) - 1)
        return (0.5 / num_nodes) * torch.mean(torch.sum(kld_element, dim=1), dim=0)
    
    def _kld_gauss_zu(self, mean_in, std_in):
        num_nodes = mean_in.size()[0]
        std_log = torch.log(std_in + self.eps)
        kld_element =  torch.mean(torch.sum(1 + 2 * std_log - mean_in.pow(2) -
                                            torch.pow(torch.exp(std_log), 2), 1))
        return (-0.5 / num_nodes) * kld_element
    
    def _nll_bernoulli(self, logits, target_adj_dense):
        temp_size = target_adj_dense.size()[0]
        temp_sum = target_adj_dense.sum()
        posw = float(temp_size * temp_size - temp_sum) / temp_sum
        norm = temp_size * temp_size / float((temp_size * temp_size - temp_sum) * 2)
        nll_loss_mat = F.binary_cross_entropy_with_logits(input=logits
                                                          , target=target_adj_dense
                                                          , pos_weight=posw
                                                          , reduction='none')
        nll_loss = -1 * norm * torch.mean(nll_loss_mat, dim=[0,1])
        return - nll_loss

## Step 6: Creating Temporal Networks and Input Tensors

#### The _neural_data_ in the following cell should be replaced by your time series data (rows being individual units/neurons/channels) and columns representing the time points

In [20]:
###### Constructing Temporal Networks

neural_data = np.random.rand(900,1024)
num_timepoints = np.shape(neural_data)[1] 

# Set parameters for sliding windows
window_size = 200 ## equals to 10 seconds
step_size = int(window_size/2)  # Choose a step size (overlap between windows)


# Create a list to store dynamic functional connectivity networks

dynamic_networks = []

all_connectivity_matrices = []
for start in range(0, num_timepoints - window_size + 1, step_size):
    end = start + window_size

    # Extract the current window of data
    current_window = neural_data[:, start:end]

    
    # Calculate functional connectivity (correlation in this case)
    current_window[np.isnan(current_window)] = 0
    connectivity_matrix = np.corrcoef(current_window)
    connectivity_matrix[np.isnan(connectivity_matrix)] = 0
    


    np.fill_diagonal(connectivity_matrix, 1)
      
    all_connectivity_matrices.append(connectivity_matrix)




adj_time_list = [sp.csr_matrix(connectivity_mat) for  connectivity_mat in all_connectivity_matrices]

num_nodes = np.shape(adj_time_list[0])[0]

### Weighted Adjacency Matrix
threshold = 0.4
edge_weights = [torch.tensor(np.abs(np.where(np.abs(connectivity_matrix.toarray()) < threshold, 0, connectivity_matrix.toarray())), dtype=torch.float32) for connectivity_matrix in adj_time_list]

### Binary Adjacency Matrix
adj_orig_dense_list = [(weighted_matrix != 0).int() for weighted_matrix in edge_weights] 


### Adding the History node to the network
edge_weights_with_history_node_list = []
for weighted_matrix in edge_weights:
    # Add a row and column for the new node
    new_node_row = torch.zeros(1, num_nodes, dtype=torch.float)
    new_node_col = torch.zeros(num_nodes+1, 1, dtype=torch.float)
    
    # Concatenate the new node row and column to the existing matrix
    weighted_matrix = torch.cat([weighted_matrix, new_node_row], dim=0)
    weighted_matrix = torch.cat([weighted_matrix, new_node_col], dim=1)
    
    # Set the connections for the new node
    weighted_matrix[-1, :] = 0  # new node is not connected to all other nodes
    weighted_matrix[:, -1] = 1  # all other nodes are connected to the new node
    weighted_matrix[-1,-1] =1
    
    # Append the modified matrix to the new list
    edge_weights_with_history_node_list.append(weighted_matrix)



adj_with_history_node_list = []
for adj_matrix in adj_orig_dense_list:
    # Add a row and column for the new node
    new_node_row = torch.zeros(1, num_nodes, dtype=torch.float)
    new_node_col = torch.zeros(num_nodes+1, 1, dtype=torch.float)
    
    # Concatenate the new node row and column to the existing matrix
    adj_matrix = torch.cat([adj_matrix, new_node_row], dim=0)
    adj_matrix = torch.cat([adj_matrix, new_node_col], dim=1)
    
    # Set the connections for the new node
    adj_matrix[-1, :] = 0  # new node is not connected to all other nodes
    adj_matrix[:, -1] = 1  # all other nodes are connected to the new node
    adj_matrix[-1,-1] =1
    
    # Append the modified matrix to the new list
    adj_with_history_node_list.append(adj_matrix)


adj_time_list_with_history_node = [sp.csr_matrix(adj_matrix) for adj_matrix in adj_with_history_node_list ]


### Node attributes (if you want to use Identity Matrix, use the commented lines)
x_in = Variable(torch.stack(edge_weights_with_history_node_list))


# seq_len = len(edge_weights_with_history_node_list)
# x_in_list = []
# for i in range(0, seq_len):
#     x_temp = torch.tensor(np.eye(num_nodes+1).astype(np.float32))
#     x_in_list.append(torch.tensor(x_temp))

# x_in = Variable(torch.stack(x_in_list))


### creating edge list
all_edges = extract_edges(adj_time_list_with_history_node)
edge_idx_list = []

for i in range(len(all_edges)):
    edge_idx_list.append(torch.tensor(np.transpose(all_edges[i]), dtype=torch.long))

## Step 7: Hyperparameters

In [21]:
# Define hyperparameters
h_dim = 32
z_dim = 16
n_layers = 1
clip = 10
learning_rate = 1e-2
num_nodes = np.shape(adj_time_list[0])[0]
eps = 1e-10
x_dim = num_nodes + 1

## Step 8: Building model

In [22]:
model = T_AVRNN(x_dim, h_dim, z_dim, n_layers, eps, bias=True)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Step 9: Training

In [23]:
last_loss = 0
for k in range(1000):
    optimizer.zero_grad()
    kld_loss, nll_loss, all_z, all_att_w = model(x_in, edge_idx_list, adj_with_history_node_list)
    loss = kld_loss + nll_loss
    loss.backward()
    optimizer.step()

    
    print('Epoch:', k)
    print('KLD Loss:', kld_loss.mean().item())
    print('NLL Loss:', nll_loss.mean().item())
    print('Total Loss:', loss.mean().item())
    print('-------------------------')
    diff = np.abs(last_loss - loss.mean().item())
    last_loss = loss.mean().item()




Epoch: 0
KLD Loss: 0.008697875775396824
NLL Loss: 7.636761665344238
Total Loss: 7.6454596519470215
-------------------------
Epoch: 1
KLD Loss: 0.004716725088655949
NLL Loss: 7.0414838790893555
Total Loss: 7.046200752258301
-------------------------
Epoch: 2
KLD Loss: 0.002717156196013093
NLL Loss: 7.102411270141602
Total Loss: 7.105128288269043
-------------------------
Epoch: 3
KLD Loss: 0.0024616438895463943
NLL Loss: 7.068718433380127
Total Loss: 7.0711798667907715
-------------------------
Epoch: 4
KLD Loss: 0.002417757175862789
NLL Loss: 6.825294017791748
Total Loss: 6.827711582183838
-------------------------
Epoch: 5
KLD Loss: 0.002120734192430973
NLL Loss: 6.773260116577148
Total Loss: 6.775381088256836
-------------------------
Epoch: 6
KLD Loss: 0.0020846568513661623
NLL Loss: 6.456920146942139
Total Loss: 6.459004878997803
-------------------------
Epoch: 7
KLD Loss: 0.0026547261513769627
NLL Loss: 6.527659893035889
Total Loss: 6.5303144454956055
-------------------------
E