In [1]:
import torch
from torch_geometric.data import Data
import numpy as np
import networkx as nx
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import matplotlib.pyplot as plt
import scipy.special as SS
import scipy.stats as SSA

In [2]:
WN = np.loadtxt('W_avg.csv')
pop = np.loadtxt('pop_new.csv')

In [30]:
class EpidemicSimulator(MessagePassing):
    def __init__(self, r, p, weight, max_time_step):
        super(EpidemicSimulator, self).__init__(aggr='add')
        self.r = r
#         self.p = p
        self.p_prime = 1-p
        self.max_time_step = max_time_step
        self.Z = 3 # latent period
        self.Zb = 1 # scale parameter for Z
        self.D = 5 # infectious period
        self.Db = 1 # scale parameter for beta
        self.weight = torch.Tensor(weight)
        self.offspring = []
    
    def forward(self, x, edge_index, edge_attr, step):
        return self.propagate(edge_index, size=(x.size(0), x.size(0)), x=x, edge_attr=edge_attr, step=step)

    def message(self, x_j, edge_index, edge_attr, step):
        # x_j has shape [E, num_features]
        # edge_attr has shape [E, num_edge_features]
        # Get the new infections from x_j.
        new_infectors = x_j[:, 2+step:3+step] ## the infectors at time ti
        temp = new_infectors.round().int()
        cases = temp.squeeze().tolist()
        # Initialize an empty tensor to store the results
        results = torch.zeros_like(new_infectors)
        # Generate negative binomial for each size
        for i, size in enumerate(cases):
#             print(size)
            if size>0:
                offspring_per_case = torch.distributions.Categorical(self.weight).sample(sample_shape=torch.Size([size]))
            #torch.distributions.negative_binomial.NegativeBinomial(self.r,self.p_prime).sample(sample_shape=torch.Size([size]))
                self.offspring.extend(offspring_per_case.tolist())
                temp_sum = offspring_per_case.sum()
            else:
                temp_sum = 0
#             print(temp_sum)
            results[i] = temp_sum
        ######^^^^^^#######
        # Compute the messages.
        messages = results * edge_attr.view(-1, 1)
        return messages

    def update(self, aggr_out, x, step):
        # x has shape [N, num_features], it is the original node features
        # The new infections are the aggregated messages.
        new_infections = aggr_out # aggr_out has shape [N, 1], it contains the updated infections
        #### Add the effective infections to the column corresponding to the current step.####
        ## immu first
        population = x[:, 1:2]
        total_infection = torch.sum(x[:, 2:3+step], dim=1,keepdim=True) 
        rate = (population - total_infection) / population # Compute the rate.
        rate[rate<0] = 0
        
        new_effective_infections = new_infections*rate
        new_infections_int  = new_effective_infections.round().int()
        ### diffuse the new_infections to different times 
        inf_sizes = new_infections_int.squeeze().tolist()
        for i, inf_size_i in enumerate(inf_sizes):
            gamma_dist1 = torch.distributions.Gamma(self.Z, 1/self.Zb)
            gamma_dist2 = torch.distributions.Gamma(self.D, 1/self.Db)
            latency_p = gamma_dist1.rsample(sample_shape=torch.Size([inf_size_i]))
            infectious_p = gamma_dist2.rsample(sample_shape=torch.Size([inf_size_i]))
            v = torch.rand(inf_size_i)
            delay_days = latency_p + v * infectious_p
#             print(step, delay_days)
            for j,delay_t in enumerate(delay_days):
                t_j = (3+step+delay_t).ceil().int()
                if t_j > self.max_time_step:
                    pass
                else:
                    x[i,t_j] = x[i,t_j] + 1
        ######^^^^^^#######
        # The rest of the features remain the same.
        other_features = x[:, 2:].clone()
        # Concatenate the new infections, the population, and the other features to get the new node features.
        x_new = torch.cat([new_infections.clone(), population, other_features], dim=1)
        return x_new, self.offspring

def simulate_dynamics(data, R0, r, num_steps):
    p = r/(R0+r)  
    xx = np.arange(0,100,1)  # define the range of x values the cutoff is 200
#     pmf = SSA.nbinom.pmf(xx, r, p)  # calculate the probability mass function
    pmf = SSA.nbinom.pmf(xx, r.detach().numpy(), p.detach().numpy())
    weights_n = pmf/np.sum(pmf)
    x = data.x
    T_len = data.x.shape[1]
    simulator = EpidemicSimulator(r,p, weights_n, max_time_step=(T_len-1))
    for ti in range(num_steps):
        x,newcases = simulator(x, data.edge_index, data.edge_attr, ti)
    return x, newcases

## inference

In [5]:
pop = np.array([1000]*4)## populations

In [25]:
# create a graph
A = np.array([[0.25 , 0.25, 0.4, 0.1 ],
        [0.25, 0.75 , 0. , 0. ],
        [0.4, 0. , 0.55 , 0.05],
        [0.1 , 0 , 0.05, 0.85 ]])
# adjacency_matrix = torch.tensor(WN)
adjacency_matrix = torch.tensor(A)
# Get the indices where the adjacency matrix has a non-zero value
edge_index = torch.nonzero(adjacency_matrix, as_tuple=False).t()

# If your adjacency matrix has edge weights, you can get them like this:
edge_weight = adjacency_matrix[edge_index[0], edge_index[1]]

T = 60
N = 4

# initial the states
xx = np.zeros((N,T+2)) # number of nodes, the columns of attributes
pop = np.array([10000]*4)
xx[:,1] = pop ## populations
## col_2 is the new infections generated by the new infectors
xx[2,2] = 10 ## the new infections at time 0 
xx = torch.tensor(xx,dtype=torch.float)

data = Data(x=xx, edge_index=edge_index, edge_attr=edge_weight)

In [36]:
data.x[:,2:]

torch.Size([4, 60])

### observation

In [27]:
NewInf_i, newcases = simulate_dynamics(data, R0=2.5, r=0.1, num_steps=T)

In [32]:
Obeserved_data = Data(x=NewInf_i, edge_index=edge_index, edge_attr=edge_weight)

In [42]:
import pyro
import torch
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions import constraints



def model(initial_conditions, data):
    # Prior on r
    r = pyro.sample("r", pyro.distributions.Uniform(0, 10))  
    
    # Simulate dynamics with the sampled r
    simulated_data, _ = simulate_dynamics(data=initial_conditions, R0=2.5, r=r, num_steps=60)
    
    # Apply the function F
    transformed_simulated_data = simulated_data[:,2:]
    
    # Likelihood
    pyro.sample("obs", pyro.distributions.Normal(transformed_simulated_data, 0.1), obs=data.x)

def guide(data):
    # Variational parameters for r
    r_loc = pyro.param("r_loc", torch.tensor(0.))
    r_scale = pyro.param("r_scale", torch.tensor(1.), constraint=constraints.positive)
    
    # Sample r
    r = pyro.sample("r", pyro.distributions.Normal(r_loc, r_scale))
    return r

# Data: This should be the observed data
observed_data = NewInf_i[:,2:]

# SVI
svi = SVI(model, guide, Adam({"lr": 0.001}), loss=Trace_ELBO())
initial_states = data

num_steps = 5000
for step in range(num_steps):
    loss = svi.step(initial_states, observed_data)
    if step % 100 == 0:
        print(f"Step {step}, Loss {loss}")


AttributeError: 'Tensor' object has no attribute 'x'