In [172]:
import networkx as nx
import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing

In [173]:
# 4x4 weighted adjacency matrix. It is symmetric, every row sums to 1, and the
# diagonal is nonzero, so each node also carries a self-weight.
A = np.array(
    [
        [0.25, 0.25, 0.40, 0.10],
        [0.25, 0.75, 0.00, 0.00],
        [0.40, 0.00, 0.55, 0.05],
        [0.10, 0.00, 0.05, 0.85],
    ]
)

In [174]:
# Build the adjacency as float32 so edge weights match the float32 node features
# created in later cells. A float64 adjacency silently promotes every
# `x_j * edge_weight` message to float64.
adjacency_matrix = torch.tensor(A, dtype=torch.float)

# COO edge list of shape [2, E]: one (row, col) column per nonzero entry of A.
# A's diagonal is nonzero, so self-loops are already part of this edge list.
edge_index = torch.nonzero(adjacency_matrix, as_tuple=False).t()

# One scalar weight per edge, aligned with the columns of `edge_index`.
edge_weight = adjacency_matrix[edge_index[0], edge_index[1]]

In [71]:
# Toy single-feature node signal: values -1, 0, 1, 2 (one per node), shape [4, 1].
x = torch.arange(-1, 3, dtype=torch.float).unsqueeze(1)

In [73]:
class WeightedDynamicsSimulator(MessagePassing):
    """One step of linear, edge-weighted diffusion on a graph.

    Every edge carries the source node's features scaled by the edge weight
    (``x_j * edge_attr``). Messages arriving at a node are summed
    (``aggr='add'``), and that sum is the node's new feature vector.
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index, edge_attr):
        num_nodes = x.size(0)
        return self.propagate(
            edge_index,
            x=x,
            edge_attr=edge_attr,
            size=(num_nodes, num_nodes),
        )

    def message(self, x_j, edge_index, edge_attr):
        # Turn the per-edge weights into a column so they broadcast across
        # every feature column of x_j.
        weights = edge_attr.view(-1, 1)
        return x_j * weights

    def update(self, aggr_out):
        # No post-processing: the aggregated messages are the new node features.
        return aggr_out

def simulate_dynamics(data, num_steps, simulator=None):
    """Apply a graph simulator ``num_steps`` times to ``data``'s node features.

    Parameters
    ----------
    data : object with ``x``, ``edge_index`` and ``edge_attr`` attributes
        For example a ``torch_geometric.data.Data``. The graph is read from
        ``data.edge_index``, not from any notebook-global ``edge_index``.
    num_steps : int
        Number of simulator applications. 0 returns ``data.x`` unchanged.
    simulator : callable, optional
        ``simulator(x, edge_index, edge_attr) -> x_new``. Defaults to a fresh
        ``WeightedDynamicsSimulator()``.

    Returns
    -------
    The node features after the last step.

    NOTE(review): a later cell redefines ``simulate_dynamics`` for the epidemic
    model. Whichever cell ran last wins.
    """
    if simulator is None:
        simulator = WeightedDynamicsSimulator()

    # Self-loops are not added here; they must already be in data.edge_index.
    # In this notebook `data` is built from A, whose diagonal is nonzero.
    x = data.x
    for _ in range(num_steps):
        # Use data's own graph. Reading a global edge_index would silently
        # simulate on whatever graph happens to be in the kernel.
        x = simulator(x, data.edge_index, data.edge_attr)
    return x

In [160]:
# Node-feature matrix: 4 nodes x 62 columns, in the layout EpidemicSimulator reads.
#   col 0  -> new infections
#   col 1  -> population
#   cols 2+ -> infection counts, summed into the node's total infected
xx = np.zeros((4, 62))
xx[:, 1] = 100        # every node has a population of 100
xx[2, [0, 2]] = 10    # seed node 2: 10 new infections and 10 already counted

In [161]:
# Convert the NumPy features to a float32 tensor (torch's default float dtype).
xx = torch.as_tensor(xx, dtype=torch.float32)

In [157]:
# Bundle the node features, COO edge list and per-edge weights into one graph object.
# NOTE(review): the saved execution counts show this cell ran (In [157]) before the
# cells that build `xx` (In [160], In [161]). Re-run it after them so that data.x is
# the current float32 feature tensor.
data = Data(
    x=xx,
    edge_index=edge_index,
    edge_attr=edge_weight,
)

In [169]:
class EpidemicSimulator(MessagePassing):
    """One step of a deterministic, edge-weighted epidemic spread.

    Node feature layout (``x`` has shape ``[N, F]`` with ``F >= 2``):

    * column 0   -- new infections produced in the current step;
    * column 1   -- population of the node;
    * columns 2+ -- infection counts, summed to give the node's total infected.

    Each node receives the sum, over its edges, of ``edge_weight * new_infections``
    of the source node (``aggr='add'``). That pressure is scaled by the susceptible
    fraction ``max(population - total_infected, 0) / population`` to give the
    node's new infections. Population and columns 2+ are carried over unchanged.

    TODO: an earlier draft sketched stochastic offspring counts via
    ``np.random.negative_binomial``; this implementation is deterministic.
    """

    def __init__(self):
        super().__init__(aggr='add')

    def forward(self, x, edge_index, edge_attr):
        num_nodes = x.size(0)
        # Passing `x` once as a keyword is enough. PyG gathers it per edge as `x_j`
        # for `message` and passes the full tensor to `update`.
        return self.propagate(edge_index, size=(num_nodes, num_nodes), x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_index, edge_attr):
        """Weight each source node's new infections by its edge weight.

        ``x_j`` is ``[E, F]``. ``edge_attr`` must hold exactly one scalar weight
        per edge (shape ``[E]`` or ``[E, 1]``); ``view(-1, 1)`` would misalign
        multi-feature edge attributes. Returns ``[E, 1]``.
        """
        new_infections = x_j[:, 0:1]
        return new_infections * edge_attr.view(-1, 1)

    def update(self, aggr_out, x):
        """Turn aggregated infection pressure into the next node-feature matrix.

        ``aggr_out`` is ``[N, 1]`` (summed messages). ``x`` is the ``[N, F]``
        input. Returns ``[N, F]`` with the same dtype as ``x``.
        """
        population = x[:, 1:2]
        previous_counts = x[:, 2:]
        total_infected = previous_counts.sum(dim=1, keepdim=True)

        # Clamp at 0 so a node whose total infected exceeds its population
        # cannot produce negative new infections.
        susceptible = (population - total_infected).clamp(min=0)
        # A zero-population node gets rate 0 instead of 0/0 = NaN.
        rate = torch.where(population > 0, susceptible / population, torch.zeros_like(population))

        # Cast back to the input dtype. Float64 edge weights would otherwise make
        # the messages float64 and mismatch the float32 columns in torch.cat.
        new_infections = (rate * aggr_out).to(x.dtype)

        return torch.cat([new_infections, population, previous_counts], dim=1)

In [150]:
def simulate_dynamics(data, num_steps, simulator=None):
    """Advance the epidemic on ``data``'s graph by ``num_steps`` steps.

    Parameters
    ----------
    data : object with ``x``, ``edge_index`` and ``edge_attr`` attributes
        For example a ``torch_geometric.data.Data``. The graph is read from
        ``data.edge_index``, not from any notebook-global ``edge_index``.
    num_steps : int
        Number of simulator applications. 0 returns ``data.x`` unchanged.
    simulator : callable, optional
        ``simulator(x, edge_index, edge_attr) -> x_new``. Defaults to a fresh
        ``EpidemicSimulator()``.

    Returns
    -------
    The node features after the last step.

    NOTE(review): this redefines ``simulate_dynamics`` from an earlier cell.
    """
    if simulator is None:
        simulator = EpidemicSimulator()

    # Self-loops are not added here; they must already be in data.edge_index.
    # In this notebook `data` is built from A, whose diagonal is nonzero.
    x = data.x
    for _ in range(num_steps):
        # Use data's own graph rather than a kernel-global edge_index.
        x = simulator(x, data.edge_index, data.edge_attr)
    return x

In [170]:
# Advance the simulation by a single step from the initial node features in `data`.
test = simulate_dynamics(data, num_steps=1)

In [171]:
test  # node features after one simulation step (col 0 = new infections, col 1 = population)

tensor([[  4.0000, 100.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000],
        [  0.0000, 100.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   0.0000,
           0.0000,   0.0000,   0.0000,   0.0000,   0.0000,   

In [103]:
# NOTE(review): `size` contains a 0, so this always returns an EMPTY array of shape
# (4, 0, 6, 1), as the saved output below shows. Confirm the intended shape.
# This line also rebinds `x`, shadowing the node-feature tensor defined earlier,
# and it draws from the unseeded global NumPy RNG.
x = np.random.negative_binomial(n=0.1, p=0.5, size=(4, 0, 6, 1))

In [104]:
x  # empty: the size passed to negative_binomial has a zero-length axis

array([], shape=(4, 0, 6, 1), dtype=int64)