In [1]:
import torch
from torch_geometric.data import Data
import numpy as np
import networkx as nx
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import matplotlib.pyplot as plt
import scipy.special as SS
import scipy.stats as SSA

In [2]:
## create a graph
# Symmetric 4-node weight matrix; every row sums to 1.
A = np.array([
    [0.25, 0.25, 0.40, 0.10],
    [0.25, 0.75, 0.00, 0.00],
    [0.40, 0.00, 0.55, 0.05],
    [0.10, 0.00, 0.05, 0.85],
])
adjacency_matrix = torch.tensor(A)

# COO edge list: a (2, E) tensor holding the (row, col) position of every non-zero entry.
edge_index = adjacency_matrix.nonzero().t()

# Weight of each edge, read back from the matrix in the same order as edge_index.
edge_weight = adjacency_matrix[tuple(edge_index)]

In [62]:
def torch_negative_binomial(n, p, size):
    """Sample ``size`` negative-binomial counts as a Gamma-Poisson mixture.

    Args:
        n: Concentration (shape) passed to the Gamma distribution.
        p: Probability parameter; the Gamma rate is ``p / (1 - p)``.
        size: Number of independent draws.

    Returns:
        A 1-D tensor of length ``size`` holding the Poisson draws.
    """
    # Step 1: one Poisson intensity per draw, Gamma(concentration=n, rate=p/(1-p)).
    gamma_rate = p / (1 - p)
    draw_shape = torch.Size([size])
    intensities = torch.distributions.Gamma(n, gamma_rate).sample(sample_shape=draw_shape)
    # Step 2: a Poisson count for each sampled intensity.
    counts = torch.distributions.Poisson(intensities).sample()
    return counts

class EpidemicSimulator(MessagePassing):
    """One stochastic epidemic step on a weighted graph, as a message-passing layer.

    Node-feature layout used by ``message`` and ``update`` (columns of ``x``):
      * column 0      -- offspring infections produced by the node at the previous step
      * column 1      -- population of the node
      * column 2 + t  -- number of infectors in the node at time step ``t``

    Each call to ``forward`` (a) spreads column 0 along the weighted edges,
    (b) schedules every received infection into a future infector column using a
    random delay, and (c) draws negative-binomial offspring for the infectors of
    the current step, scaled by the remaining susceptible fraction.

    Args:
        r: Concentration parameter handed to ``torch_negative_binomial``.
        p: Probability parameter handed to ``torch_negative_binomial``.
        max_time_step: Largest column index of ``x`` that may receive scheduled infectors.
    """

    # Delay distributions are Gamma(concentration, rate=1/scale).
    Z = 3   # latent period (Gamma concentration)
    Zb = 1  # scale parameter for Z
    D = 5   # infectious period (Gamma concentration)
    Db = 1  # scale parameter for D

    def __init__(self, r, p, max_time_step):
        super().__init__(aggr='add')
        self.r = r
        self.p = p
        self.max_time_step = max_time_step
        print('self.max_time_step', self.max_time_step)

    def forward(self, x, edge_index, edge_attr, step):
        """Run one simulation step and return the updated ``[N, num_features]`` matrix."""
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_attr=edge_attr, step=step)

    def message(self, x_j, edge_index, edge_attr):
        """Send each source node's new infections (column 0) along its edges, weighted by ``edge_attr``."""
        # x_j has shape [E, num_features]; keep column 0 as an [E, 1] slice.
        new_infections = x_j[:, 0:1]
        return new_infections * edge_attr.view(-1, 1)

    def update(self, aggr_out, x, step):
        """Schedule the received infections and generate the next offspring counts.

        Args:
            aggr_out: ``[N, 1]`` sum of incoming messages per node.
            x: ``[N, num_features]`` node features before this step (not modified).
            step: Current time step; the infectors of this step live in column ``2 + step``.

        Returns:
            A new ``[N, num_features]`` tensor: new offspring in column 0, population in
            column 1, and the updated infector time line in the remaining columns.

        Raises:
            ValueError: If ``x`` has no column for ``step``.
        """
        num_cols = x.size(1)
        if not 0 <= step <= num_cols - 3:
            raise ValueError(
                f"step={step} needs column {2 + step}, but x only has {num_cols} columns"
            )

        # Work on a copy so the tensor owned by the caller is never changed in place.
        x = x.clone()
        # Scheduled infectors may only land in columns that exist and are within the horizon.
        last_col = min(self.max_time_step, num_cols - 1)

        # ---- 1. spread the received infections over future time columns ----
        # view(-1) (instead of squeeze) keeps a list even when the graph has a single node.
        arrivals = aggr_out.round().int().view(-1).tolist()
        # The delay distributions do not depend on the node, so build them once.
        # Use the class attributes explicitly; bare Z/Zb/D/Db are not in scope here.
        latent_dist = torch.distributions.Gamma(self.Z, 1 / self.Zb)
        infectious_dist = torch.distributions.Gamma(self.D, 1 / self.Db)
        for i, n_new in enumerate(arrivals):
            if n_new <= 0:
                continue
            draw_shape = torch.Size([n_new])
            latency_p = latent_dist.sample(sample_shape=draw_shape)
            infectious_p = infectious_dist.sample(sample_shape=draw_shape)
            v = torch.rand(n_new)
            # Delay = latent period plus a uniform fraction of the infectious period.
            delay_days = latency_p + v * infectious_p
            target_cols = (1 + step + delay_days).ceil().long()
            # Infections scheduled beyond the horizon are dropped.
            target_cols = target_cols[target_cols <= last_col]
            counts = torch.bincount(target_cols, minlength=num_cols)
            x[i] += counts.to(device=x.device, dtype=x.dtype)

        # ---- 2. offspring generated by the infectors of the current step ----
        population = x[:, 1:2]
        new_generation = x[:, 2 + step:3 + step]  # infectors at the current step
        total_infection = torch.sum(x[:, 2:3 + step], dim=1, keepdim=True)
        # Remaining susceptible fraction, floored at 0; an empty node gets rate 0
        # instead of the NaN/inf a division by zero would produce.
        susceptible = (population - total_infection).clamp(min=0)
        rate = torch.where(population > 0, susceptible / population, torch.zeros_like(population))

        sizes = new_generation.round().int().view(-1).tolist()
        results = torch.empty_like(new_generation)
        for i, size in enumerate(sizes):
            # One negative-binomial draw per infector, summed per node.
            offspring = torch_negative_binomial(self.r, self.p, size)
            results[i] = rate[i] * offspring.sum()

        # Column 0 is replaced by the new offspring; population and time line are carried over.
        other_features = x[:, 2:]
        return torch.cat([results, population, other_features], dim=1)

def simulate_dynamics(data, R0, r, num_steps, max_time_step=None):
    """Run the epidemic simulator for ``num_steps`` steps on ``data``.

    Args:
        data: Graph object providing ``x`` (``[N, num_features]`` node features),
            ``edge_index`` and ``edge_attr``.
        R0: Value used with ``r`` to derive the probability ``p = r / (R0 + r)``.
        r: Concentration parameter passed to the simulator.
        num_steps: Number of simulation steps; step ``t`` reads column ``2 + t`` of ``x``.
        max_time_step: Largest column index that may receive scheduled infectors.
            Defaults to the last column of ``data.x`` (61 for a 62-column matrix).

    Returns:
        The ``[N, num_features]`` node-feature tensor after the last step.

    Raises:
        ValueError: If ``data.x`` has too few columns for ``num_steps`` steps.
    """
    # Copy so that nothing done during the simulation can alter data.x.
    x = data.x.clone()
    num_cols = x.size(1)
    if num_steps > num_cols - 2:
        raise ValueError(
            f"num_steps={num_steps} needs at least {num_steps + 2} feature columns, got {num_cols}"
        )
    if max_time_step is None:
        max_time_step = num_cols - 1

    p = r / (R0 + r)
    simulator = EpidemicSimulator(r, p, max_time_step)
    for ti in range(num_steps):
        # Take the edges from `data` itself rather than from a notebook-level global.
        x = simulator(x, data.edge_index, data.edge_attr, ti)
    return x

In [65]:
## node characteristics
# Fix the RNG so the stochastic simulation gives the same result on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

NUM_STEPS = 60                        # number of simulated time steps
NUM_NODES = adjacency_matrix.size(0)  # one feature row per graph node
NUM_FEATURES = 2 + NUM_STEPS          # col 0: new infections, col 1: population, cols 2..: infectors per time step

xx = torch.zeros((NUM_NODES, NUM_FEATURES), dtype=torch.float)
xx[:, 1] = 1000  ## populations
xx[2, 2] = 10    ## node 2 starts with 10 infectors at time 0
## data structures
data = Data(x=xx, edge_index=edge_index, edge_attr=edge_weight)
test = simulate_dynamics(data, R0=2.5, r=1, num_steps=NUM_STEPS)

self.max_time_step 61
^___________________________^
individual 0 tensor(7, dtype=torch.int32)
individual 0 tensor(13, dtype=torch.int32)
individual 0 tensor(7, dtype=torch.int32)
individual 0 tensor(6, dtype=torch.int32)
individual 0 tensor(6, dtype=torch.int32)
individual 0 tensor(9, dtype=torch.int32)
individual 0 tensor(9, dtype=torch.int32)
individual 0 tensor(11, dtype=torch.int32)
individual 2 tensor(7, dtype=torch.int32)
individual 2 tensor(7, dtype=torch.int32)
individual 2 tensor(4, dtype=torch.int32)
individual 2 tensor(14, dtype=torch.int32)
individual 2 tensor(6, dtype=torch.int32)
individual 2 tensor(8, dtype=torch.int32)
individual 2 tensor(8, dtype=torch.int32)
individual 2 tensor(16, dtype=torch.int32)
individual 2 tensor(9, dtype=torch.int32)
individual 2 tensor(8, dtype=torch.int32)
individual 2 tensor(7, dtype=torch.int32)
individual 3 tensor(8, dtype=torch.int32)
^___________________________^
^___________________________^
individual 0 tensor(9, dtype=torch.int32)
in

In [66]:
# Display the node-feature matrix returned by simulate_dynamics: one row per node,
# col 0 = offspring from the last step, col 1 = population, cols 2.. = infectors per time step.
test

tensor([[   0., 1000.,    0.,    0.,    0.,    0.,    2.,    2.,    0.,    3.,
            1.,    4.,    1.,    3.,    3.,    7.,    2.,    8.,    4.,    9.,
            9.,   11.,   14.,   10.,   15.,   21.,   20.,   32.,   28.,   26.,
           32.,   41.,   46.,   72.,   48.,   59.,   67.,   66.,   63.,   63.,
           60.,   55.,   50.,   55.,   47.,   35.,   25.,   26.,   23.,   11.,
           12.,    4.,    4.,    3.,    3.,    1.,    0.,    0.,    0.,    0.,
            0.,    0.],
        [   0., 1000.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,
            0.,    0.,    2.,    1.,    3.,    0.,    1.,    2.,    2.,    5.,
            1.,    4.,    3.,   10.,    8.,    9.,   12.,   12.,   20.,   14.,
           20.,   23.,   27.,   28.,   42.,   48.,   57.,   44.,   54.,   49.,
           78.,   69.,   61.,   71.,   74.,   59.,   50.,   55.,   39.,   22.,
           23.,    9.,   15.,    9.,    3.,    0.,    1.,    1.,    0.,    0.,
            0.,    0.],
    