In [1]:
import torch
from torch_geometric.data import Data
import numpy as np
import networkx as nx
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import matplotlib.pyplot as plt
import scipy.special as SS
import scipy.stats as SSA

In [3]:
WN = np.loadtxt('../W_avg.csv')
pop = np.loadtxt('../pop_new.csv')

In [159]:
class EpidemicSimulator(MessagePassing):
    def __init__(self, r, p, weight, max_time_step):
        super(EpidemicSimulator, self).__init__(aggr='add')
        self.r = r ## to be estimated
#         self.p = p
        self.p_prime = 1-p ## depends on R0
        self.max_time_step = max_time_step
        self.Z = 3 # latent period
        self.Zb = 1 # scale parameter for Z
        self.D = 5 # infectious period
        self.Db = 1 # scale parameter for beta 
        self.weight = torch.Tensor(weight) 
        self.offspring = []
    
    def forward(self, x, edge_index, edge_attr, step):
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_attr=edge_attr, step=step)


    def message(self, x, edge_index, edge_attr, step):
        # x_j has shape [E, num_features]
        # edge_attr has shape [E, num_edge_features]
        # Get the new infections from x_j.
        # Initialize an empty tensor to store the results
        new_infectors = x[:, 2+step:3+step]  # the infectors at time ti
        
        temp = new_infectors.round().int()
        cases = temp.squeeze().tolist() ##case' indx is the node index
        
        results = torch.zeros_like(new_infectors)
        print(new_infectors)
        source_nodes, target_nodes = edge_index
        messages = torch.zeros_like(edge_attr)
        # Generate negative binomial for each size
        for i, size in enumerate(cases): ###i is index for the edges
            if size > 0:
                offspring_per_case = torch.tensor([2]*size)
                #torch.distributions.Categorical(self.weight).sample(sample_shape=torch.Size([size]))
                self.offspring.extend(offspring_per_case.tolist())
                temp_sum = offspring_per_case.sum()
            else:
                temp_sum = 0
            results[i] = temp_sum
        # New code for sending messages based on edge weights
        for i, num_messages in enumerate(results):
            for _ in range(int(num_messages)):
                # For each message, choose a target node based on the edge probabilities
                possible_targets = target_nodes[source_nodes == i]
                target_probabilities = edge_attr[source_nodes == i]
                chosen_target = torch.multinomial(target_probabilities, 1).item()
                target_idx = (source_nodes == i) & (target_nodes == possible_targets[chosen_target])
                messages[target_idx] += 1
        return messages.view(-1, 1)

    def update(self, aggr_out, x, step):
    # x has shape [N, num_features], it is the original node features
    # The new infections are the aggregated messages.
        new_infections = aggr_out # aggr_out has shape [N, 1], it contains the updated infections
        #### Add the effective infections to the column corresponding to the current step.####
        ## immu first
        population = x[:, 1:2]
        total_infection = torch.sum(x[:, 2:3+step], dim=1,keepdim=True) 
        rate = (population - total_infection) / population # Compute the rate.
        rate[rate<0] = 0

        new_effective_infections = new_infections*rate
        new_infections_int  = new_effective_infections.round().int()
        ### diffuse the new_infections to different times 
        inf_sizes = new_infections_int.squeeze().tolist()
        for i, inf_size_i in enumerate(inf_sizes):
            gamma_dist1 = torch.distributions.Gamma(self.Z, 1/self.Zb)
            gamma_dist2 = torch.distributions.Gamma(self.D, 1/self.Db)
            latency_p = gamma_dist1.rsample(sample_shape=torch.Size([inf_size_i]))
            infectious_p = gamma_dist2.rsample(sample_shape=torch.Size([inf_size_i]))
            v = torch.rand(inf_size_i)
            delay_days = latency_p + v * infectious_p
#             print(step, delay_days)
            for j,delay_t in enumerate(delay_days):
                t_j = (3+step+delay_t).ceil().int()
                if t_j > self.max_time_step:
                    pass
                else:
                    x[i,t_j] = x[i,t_j] + 1
        ######^^^^^^#######
        # The rest of the features remain the same.
        other_features = x[:, 2:].clone()
        # Concatenate the new infections, the population, and the other features to get the new node features.
        x_new = torch.cat([new_infections.clone(), population, other_features], dim=1)
        return x_new, self.offspring


In [165]:
def simulate_dynamics(data, R0, r, num_steps):
    p = r/(R0+r)
    xx = np.arange(0, 100, 1)  # define the range of x values the cutoff is 200
#     pmf = SSA.nbinom.pmf(xx, r, p)  # calculate the probability mass function
    pmf = SSA.nbinom.pmf(xx, r.detach().numpy(), p.detach().numpy())
    weights_n = pmf/np.sum(pmf)
#     print(weights_n)
    x = data.x
    T_len = x.shape[1]
    E_x = torch.tensor(np.zeros(300,x.shape[0],x.shape[1]))
#     E_Newcases = torch.tensor(np.zeros(300,x.shape[0],x.shape[1]))
    for e_i in range(300):
        simulator = EpidemicSimulator(r, p, weights_n, max_time_step=(T_len-1))
        for ti in range(num_steps):
            E_x[e_i,:,:], _ = simulator(x, data.edge_index, data.edge_attr, ti)
    return E_x

In [161]:
pop = np.array([1000]*4)

In [162]:
# create a graph
A = np.array([[0.25 , 0.25, 0.4, 0.1 ],
        [0.25, 0.75 , 0. , 0. ],
        [0.4, 0. , 0.55 , 0.05],
        [0.1 , 0 , 0.05, 0.85 ]])
# adjacency_matrix = torch.tensor(WN)
adjacency_matrix = torch.tensor(A)
# Get the indices where the adjacency matrix has a non-zero value
edge_index = torch.nonzero(adjacency_matrix, as_tuple=False).t()

# If your adjacency matrix has edge weights, you can get them like this:
edge_weight = adjacency_matrix[edge_index[0], edge_index[1]]

In [163]:
xx = np.zeros((4,62)) # number of nodes, the columns of attributes
xx[:,1] = pop ## populations
## col_2 is the new infections generated by the new infectors

xx[0,2] = 10

xx = torch.tensor(xx,dtype=torch.float)

data = Data(x=xx, edge_index=edge_index, edge_attr=edge_weight)

In [164]:
NewInf_i, newcases = simulate_dynamics(data, R0=2.5, r=torch.tensor(2.0), num_steps=2)

tensor([[10.],
        [ 0.],
        [ 0.],
        [ 0.]])
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], dtype=torch.float64)
0 tensor([0.2500, 0.2500, 0.4000, 0.1000], 

In [154]:
NewInf_i

tensor([[   8., 1000.,   10.,    0.,    0.,    0.,    0.,    0.,    1.,    3.,
            1.,    1.,    0.,    1.,    1.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.],
        [   5., 1000.,    0.,    0.,    0.,    0.,    0.,    1.,    2.,    1.,
            0.,    1.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.],
    