In [None]:
import numpy as np
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
%run DataSEIR.ipynb
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# DGL reduce function: product of each node's incoming edge messages
def prod(nodes):
    """Reduce UDF: multiply together the messages waiting in each node's mailbox.

    Returns a dict with key 'prod' mapping to a double tensor that holds, for
    every node in the bucket, the product of ``nodes.mailbox['m']`` along the
    message axis (dim 1).
    """
    incoming = nodes.mailbox['m']
    node_products = incoming.prod(dim=1, dtype=torch.double)
    return {'prod': node_products}

import dgl.function as fn

# DMP layer propagation
class DMP(nn.Module):
    """Dynamic message-passing (DMP) update of SIR marginals on a DGL graph.

    Parameters
    ----------
    lamb : float
        Multiplies ``phi`` in the ``theta`` update (presumably the per-edge
        infection probability).
    mu : float
        Multiplies ``pi`` in the ``pr`` update (presumably the per-step
        recovery probability).
    """

    def __init__(self, lamb, mu):
        super(DMP, self).__init__()
        # decay factor applied to phi at every step
        self.factor = (1-mu)*(1-lamb)
        self.lamb = lamb
        self.mu = mu

    def forward(self, g, iterations, edge_index):
        """Run ``iterations`` DMP steps on ``g`` in place and return it.

        Parameters
        ----------
        g : DGLGraph
            Input graph. Must carry edge data 'theta', 'phi', 'p' and node
            data 'p_zero', 'pi', 'pr' (as set up by ``initialize``).
        iterations : int
            Number of time steps to propagate.
        edge_index : torch.Tensor
            Edges of ``g`` in COO format (row 0 sources, row 1 targets),
            assumed to be in ``g``'s edge-ID order, as when ``g`` was built
            from it.

        Returns
        -------
        DGLGraph
            The same graph; node data 'ps', 'pi', 'pr' hold the marginals
            after ``iterations`` steps.
        """
        # For each edge u->v: whether v->u exists and, if so, its edge id.
        self.inverse, self.perm = inverse_finder(edge_index.cpu())
        # update_all leaves 'prod' at 0 on nodes without in-edges, whose empty
        # product must be 1. Select those nodes by degree rather than by value,
        # so a genuine zero product (an incoming theta reached 0) stays 0.
        # The graph structure is fixed, so this is computed once.
        no_in_edges = g.in_degrees() == 0
        for _ in range(iterations):
            # p_old is p(t-1)
            g.edata['p_old'] = g.edata['p']

            # update theta(t)
            g.edata['theta'] = g.edata['theta']-self.lamb*g.edata['phi']

            # message passing: node attribute 'prod' = product of incoming theta
            g.update_all(message_func=fn.copy_e('theta', 'm'), reduce_func=prod)
            tprod = g.ndata['prod']
            tprod[no_in_edges] = 1
            g.ndata['prod'] = tprod

            # edge u->v: product at the source u, used when v->u does not exist
            g.apply_edges(fn.copy_u('prod', 'own_prod'))
            # edge u->v: prod[v] / theta[u->v], i.e. v's product without u's
            # contribution; it is read through perm for the reverse edge.
            # NOTE(review): this divides by zero if some theta reaches 0.
            g.apply_edges(fn.v_div_e('prod', 'theta', 'inverse_reduced_prod'))

            own_prod = g.edata['own_prod']
            inverse_reduced_prod = g.edata['inverse_reduced_prod']

            # where the reverse edge exists, exclude its theta from the product
            own_prod[self.inverse] = inverse_reduced_prod[self.perm][self.inverse]

            g.edata['p'] = own_prod

            # multiply in the source's P(t=0)
            g.apply_edges(fn.u_mul_e('p_zero', 'p', 'p'))

            # update phi(t)
            g.edata['phi'] = self.factor*g.edata['phi']-(g.edata['p']-g.edata['p_old'])

            # marginal probabilities at t; pr must use pi from t-1, so order matters
            g.ndata['ps'] = g.ndata['p_zero']*g.ndata['prod']
            g.ndata['pr'] = g.ndata['pr']+self.mu*g.ndata['pi']
            g.ndata['pi'] = 1-g.ndata['ps']-g.ndata['pr']
        return g

In [None]:
# find where there are inverse (undirected) edges in the graph and how to permute them onto each other
def inverse_finder(edge_index):
    """Locate, for every edge, its reverse edge in the same edge list.

    Parameters
    ----------
    edge_index : torch.Tensor, shape (2, E)
        Edges in COO format (row 0 sources, row 1 targets). It may live on
        any device; it is copied to the CPU here.

    Returns
    -------
    inverse : torch.BoolTensor, shape (E,)
        True where edge i = (a, b) has exactly one reverse edge (b, a).
        A self-loop counts as its own reverse.
    perm : torch.LongTensor, shape (E,)
        Index of that reverse edge where ``inverse`` is True, else i itself.

    Edges whose reverse appears zero times or more than once stay unmatched.
    A hash map keeps the cost O(E) instead of comparing every edge against
    the whole edge list (O(E**2)).
    """
    connections = edge_index.detach().cpu().numpy().T.tolist()
    # (source, target) -> every edge position holding that pair
    positions = {}
    for idx, (src, dst) in enumerate(connections):
        positions.setdefault((src, dst), []).append(idx)

    perm = []
    inverse = []
    for idx, (src, dst) in enumerate(connections):
        matches = positions.get((dst, src), [])
        unique = len(matches) == 1
        perm.append(matches[0] if unique else idx)
        inverse.append(unique)
    return torch.tensor(inverse, dtype=torch.bool), torch.tensor(perm, dtype=torch.long)

In [None]:
# initialize message passing network / layer
class initialize(nn.Module):
    """Build a DGL graph carrying the DMP state at t = 0.

    Nodes listed in ``infected`` start infected; all other nodes start
    susceptible and nobody starts recovered.
    """

    def __init__(self):
        super(initialize, self).__init__()

    def forward(self, edge_index, infected, num_nodes):
        """Create the graph and attach its initial DMP attributes.

        Parameters
        ----------
        edge_index : torch.Tensor
            Edge index in COO format (row 0 sources, row 1 targets).
        infected : list of int
            Nodes infected at t = 0.
        num_nodes : int
            Number of nodes in the graph, isolated nodes included.

        Returns
        -------
        DGLGraph on ``device`` with edge data 'theta', 'phi', 'p' and node
        data 'p_zero', 'ps', 'pi', 'pr' (all double).
        """
        graph = dgl.graph((edge_index[0], edge_index[1]), num_nodes=num_nodes).to(device)
        sources = edge_index.cpu().numpy()[0]

        # P(susceptible at t=0): 0 for the infected nodes, 1 for everyone else
        p_zero = torch.ones(num_nodes, dtype=torch.double).to(device)
        p_zero[infected] = 0

        # phi starts at 1 on edges leaving an infected node, 0 elsewhere
        leaves_infected = torch.from_numpy(np.isin(sources, infected))
        phi0 = leaves_infected.double().to(device)

        # edge attributes at t=0
        graph.edata['theta'] = torch.ones(graph.num_edges(), dtype=torch.double).to(device)
        graph.edata['phi'] = phi0
        graph.edata['p'] = 1 - phi0
        graph.ndata['p_zero'] = p_zero

        # marginal probabilities at t=0
        graph.ndata['ps'] = p_zero
        graph.ndata['pi'] = 1 - p_zero
        graph.ndata['pr'] = torch.zeros_like(p_zero)
        return graph

In [None]:
# predict likelihood
class likelihood(nn.Module):
    """Score an observed snapshot against the DMP marginals of each batch.

    State codes used in ``snapshot``:
    0: S
    1: I
    2: R

    Parameters
    ----------
    scale : float, optional
        Every marginal is multiplied by this factor before the per-batch
        product is taken (default 2.5, the value previously hard-coded).
        Presumably it keeps the product of many probabilities from
        underflowing. Each batch's score is therefore multiplied by
        ``scale ** batch_size``, so compare scores only between batches of
        equal size.
    """

    def __init__(self, scale=2.5):
        super(likelihood, self).__init__()
        self.scale = scale

    def forward(self, g, batch_mask, snapshot, time_steps, num_nodes):
        """Return one (scaled) likelihood per batch.

        Parameters
        ----------
        g : DGLGraph
            Graph with node data 'ps', 'pi', 'pr'.
        batch_mask : torch.Tensor
            Batch id of every node of ``g``, numbered 1, 2, ...
        snapshot : torch.Tensor
            Observed state code of each node within one batch. It is flattened
            and reused for every batch, so each batch should have
            ``snapshot.numel()`` nodes.
        time_steps, num_nodes : int
            Unused; kept for interface compatibility.

        Returns
        -------
        numpy.ndarray of float64
            Entry ``b - 1`` is the product, over the nodes of batch ``b``, of
            ``scale`` times the marginal of the node's observed state. An
            empty ``batch_mask`` gives an empty array.
        """
        snapshot = snapshot.reshape(-1).long()
        batch_mask = batch_mask.cpu().numpy()
        if batch_mask.size == 0:
            return np.empty(0, dtype=np.double)

        scores = []
        for batch in range(1, int(batch_mask.max()) + 1):
            in_batch = batch_mask == batch
            # rows: S, I, R marginals; columns: nodes of this batch
            p_all = torch.stack((g.ndata['ps'][in_batch],
                                 g.ndata['pi'][in_batch],
                                 g.ndata['pr'][in_batch])) * self.scale
            runner = torch.arange(p_all.shape[1])
            # for every node, pick the marginal of its observed state
            p_total = torch.prod(p_all[snapshot, runner], dtype=torch.double)
            scores.append(p_total.item())
        # build the array once; np.append inside the loop copied it every time
        return np.asarray(scores, dtype=np.double)

In [None]:
# Combine all of the above and implement mini-batching: build one large network made of N disconnected copies with different starting conditions, which are propagated at the same time (vectorized).
class dmp_layer():
    """Evaluate every node as a candidate patient zero in one batched DMP run.

    The network is replicated ``num_nodes`` times as disconnected copies. In
    copy ``k`` (global node ids ``k*num_nodes`` .. ``(k+1)*num_nodes - 1``)
    only local node ``k`` starts infected. All copies are propagated together.

    Parameters
    ----------
    lamb, mu : float
        Passed on to ``DMP``.
    """
    def __init__(self, lamb, mu):
        super(dmp_layer, self).__init__()
        self.init_dmp = initialize()
        self.dmp = DMP(lamb, mu)
        self.li = likelihood()

    def init_graph(self, edge_index, num_nodes):
        """Build the replicated graph with one patient zero per copy.

        Parameters
        ----------
        edge_index : torch.Tensor, shape (2, E)
            Edges of the original network in COO format.
        num_nodes : int
            Number of nodes of the original network.

        Returns
        -------
        g : DGLGraph
            Initialised graph with ``num_nodes**2`` nodes.
        edge_index : torch.Tensor, shape (2, num_nodes * E)
            Edges of all copies, laid out copy after copy, on ``device``.
        """
        # patient zero of copy k is local node k, i.e. global id k*num_nodes + k
        patient_zeros = list((num_nodes + 1) * np.arange(num_nodes))
        # Build the offsets on edge_index's own device, so a CPU edge_index no
        # longer clashes with a CUDA arange.
        edge_index = edge_index.to(device)
        offsets = num_nodes * torch.arange(num_nodes, device=edge_index.device)
        # (num_nodes, 2, E): copy k is edge_index shifted by k*num_nodes.
        # Broadcasting replaces the deprecated .T on a 3-D tensor.
        copies = edge_index.unsqueeze(0) + offsets.view(-1, 1, 1)
        larger_edge_index = copies.permute(1, 0, 2).reshape(2, -1)

        g = self.init_dmp(edge_index=larger_edge_index, infected=patient_zeros,
                          num_nodes=num_nodes*num_nodes)
        return g, larger_edge_index

    def likelihood_estimation(self, g, snapshot, num_nodes, time_steps):
        """Score ``snapshot`` against every copy; returns one value per copy."""
        # batch id of every global node: [1]*num_nodes, [2]*num_nodes, ...
        batch_mask = torch.arange(1, num_nodes + 1).repeat_interleave(num_nodes).int()
        p = self.li(g, batch_mask, snapshot, time_steps, num_nodes)
        return p

    def predict(self, edge_index, snapshot, time_steps):
        """Likelihood of ``snapshot`` for each candidate patient zero.

        Parameters
        ----------
        edge_index : torch.Tensor, shape (2, E)
            Edges of the network in COO format.
        snapshot : torch.Tensor
            Observed state code per node (0: S, 1: I, 2: R); its length sets
            the number of nodes.
        time_steps : int
            Number of DMP iterations to run.

        Returns
        -------
        numpy.ndarray
            Entry k is the (scaled) likelihood when node k is patient zero.
        """
        num_nodes = snapshot.shape[0]
        g, larger_edge_index = self.init_graph(edge_index, num_nodes)
        g = self.dmp(g, iterations=time_steps, edge_index=larger_edge_index)
        p = self.likelihood_estimation(g, snapshot, num_nodes, time_steps)
        return p