Skip to content

Commit

Permalink
added sir model with tests comparing to nndlib
Browse files Browse the repository at this point in the history
  • Loading branch information
arnauqb committed May 12, 2023
1 parent 7fac6c6 commit f69bccf
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 0 deletions.
82 changes: 82 additions & 0 deletions birds/models/sir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torch_geometric

class SIRMessagePassing(torch_geometric.nn.conv.MessagePassing):
def forward(self, edge_index, infected, susceptible):
return self.propagate(edge_index, x=infected, y=susceptible)
def message(self, x_j, y_i):
return x_j * y_i

class SIR(torch.nn.Module):
def __init__(self, graph, n_timesteps):
"""
Implements a differentiable SIR model on a graph.
Arguments:
graph (networkx.Graph) : a networkx graph
n_timesteps (int) : the number of timesteps to run the model for
"""
super().__init__()
self.n_timesteps = n_timesteps
# convert graph from networkx to pytorch geometric
self.graph = torch_geometric.utils.convert.from_networkx(graph)
self.mp = SIRMessagePassing(aggr='add', node_dim=-1)

def sample_bernoulli_gs(self, probs, tau=0.1):
"""
Samples from a Bernoulli distribution in a diferentiable way using Gumble-Softmax
Arguments:
probs (torch.Tensor) : a tensor of shape (n,) containing the probabilities of success for each trial
tau (float) : the temperature of the Gumble-Softmax distribution
"""
logits = torch.vstack((probs, 1 - probs)).T.log()
gs_samples = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True)
return gs_samples[:,0]


def forward(self, params):
"""
Runs the model forward
Arguments:
params (torch.Tensor) : a tensor of shape (3,) containing the fraction of infected, beta, and gamma
"""
# Initialize the parameters
initial_infected = params[0]
beta = params[1]
gamma = params[2]
n_agents = self.graph.num_nodes
# Initialize the state
infected = torch.zeros(n_agents)
susceptible = torch.ones(n_agents)
recovered = torch.zeros(n_agents)
# sample the initial infected nodes
probs = initial_infected * torch.ones(n_agents)
new_infected = self.sample_bernoulli_gs(probs)
infected += new_infected
susceptible -= new_infected

infected_hist = infected.sum().reshape((1,))
recovered_hist = torch.zeros((1,))

# Run the model forward
for _ in range(self.n_timesteps):
# Get number of infected neighbors per node, return 0 if node is not susceptible.
n_infected_neighbors = self.mp(self.graph.edge_index, infected, susceptible)
# each contact has a beta chance of infecting a susceptible node
prob_infection = 1 - (1 - beta) ** n_infected_neighbors
# sample the infected nodes
new_infected = self.sample_bernoulli_gs(prob_infection)
# sample recoverd people
prob_recovery = gamma * infected
new_recovered = self.sample_bernoulli_gs(prob_recovery)
# update the state of the agents
infected = infected + new_infected - new_recovered
susceptible -= new_infected
recovered += new_recovered
infected_hist = torch.hstack((infected_hist, infected.sum().reshape(1,)))
recovered_hist = torch.hstack((recovered_hist, recovered.sum().reshape(1,)))

return infected_hist, recovered_hist

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
networkx==3.0
nflows==0.14
torch==2.0
torch-geometric==2.3
Expand Down
109 changes: 109 additions & 0 deletions test/models/test_sir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
import numpy as np
import networkx

from birds.models.sir import SIR


class TestSIR:
"""
Tests the SIR implementation.
"""

def test__implementation_against_nnlib(self):
"""
Compares results with the nnlib implementation.
https://github.com/GiulioRossetti/ndlib/blob/master/docs/reference/models/epidemics/SIR.rst
```python
import networkx as nx
import ndlib.models.ModelConfig as mc
import ndlib.models.epidemics as ep
import numpy as np
# Network topology
g = nx.erdos_renyi_graph(1000, 0.01)
n_timesteps = 10
beta = 0.05
gamma = 0.10
fraction_infected = 0.05
# Model selection
model = ep.SIRModel(g)
# Model Configuration
cfg = mc.Configuration()
cfg.add_model_parameter("fraction_infected", fraction_infected)
cfg.add_model_parameter('beta', beta)
cfg.add_model_parameter('gamma', gamma)
model.set_initial_status(cfg)
# Simulation execution
status = np.array(list(model.status.values()))
infected = (status == 1).sum()
recovered = 0
susceptible = N - infected
iterations = model.iteration_bunch(n_timesteps)
susc = []
inf = []
rec = []
for iteration in iterations:
status = np.array(list(iteration["status"].values()))
new_infected = (status == 1).sum()
new_recovered = (status == 2).sum()
infected = infected + new_infected - new_recovered
recovered += new_recovered
susceptible -= new_infected
susc.append(susceptible)
inf.append(infected)
rec.append(recovered)
print(inf)
print(rec)
```
>>> [50, 100, 114, 135, 174, 217, 260, 304, 329, 377, 407]
>>> [0, 0, 3, 6, 11, 25, 41, 62, 93, 122, 163]
"""
N = 1000
beta = 0.05
gamma = 0.10
fraction_infected = 0.05
n_timesteps = 10 # we do one more...
graph = networkx.erdos_renyi_graph(N, 0.01)
model = SIR(graph=graph, n_timesteps=n_timesteps)
infected, recovered = model(
torch.tensor([fraction_infected, beta, gamma])
) # fraction infected, beta, and gamma
exp_infected = torch.tensor(
[50, 100, 114, 135, 174, 217, 260, 304, 329, 377, 407], dtype=torch.float
)
exp_recovered = torch.tensor(
[0, 0, 3, 6, 11, 25, 41, 62, 93, 122, 163], dtype=torch.float
)
# check initial infected fraction
assert np.isclose(infected[0], 0.05 * N, rtol=0.3, atol=2)
# values from nndlib run
assert torch.allclose(
infected[:-1], exp_infected[1:], rtol=1.0
) # check is close within a factor of 2... seed is tricky here
# they do recovery differently...
assert torch.allclose(recovered[:-1], exp_recovered[1:], rtol=1.0, atol=5)

def test__gradient_propagates(self):
"""
Checks that the gradient propagates through the model.
"""
N = 1000
beta = 0.05
gamma = 0.10
fraction_infected = 0.05
n_timesteps = 10
graph = networkx.erdos_renyi_graph(N, 0.01)
model = SIR(graph=graph, n_timesteps=n_timesteps)
probs = torch.tensor([fraction_infected, beta, gamma], requires_grad=True)
infected, _ = model(probs)
infected.sum().backward()
assert probs.grad is not None

0 comments on commit f69bccf

Please sign in to comment.