-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added sir model with tests comparing to nndlib
- Loading branch information
Showing
3 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import torch | ||
import torch_geometric | ||
|
||
class SIRMessagePassing(torch_geometric.nn.conv.MessagePassing): | ||
def forward(self, edge_index, infected, susceptible): | ||
return self.propagate(edge_index, x=infected, y=susceptible) | ||
def message(self, x_j, y_i): | ||
return x_j * y_i | ||
|
||
class SIR(torch.nn.Module): | ||
def __init__(self, graph, n_timesteps): | ||
""" | ||
Implements a differentiable SIR model on a graph. | ||
Arguments: | ||
graph (networkx.Graph) : a networkx graph | ||
n_timesteps (int) : the number of timesteps to run the model for | ||
""" | ||
super().__init__() | ||
self.n_timesteps = n_timesteps | ||
# convert graph from networkx to pytorch geometric | ||
self.graph = torch_geometric.utils.convert.from_networkx(graph) | ||
self.mp = SIRMessagePassing(aggr='add', node_dim=-1) | ||
|
||
def sample_bernoulli_gs(self, probs, tau=0.1): | ||
""" | ||
Samples from a Bernoulli distribution in a diferentiable way using Gumble-Softmax | ||
Arguments: | ||
probs (torch.Tensor) : a tensor of shape (n,) containing the probabilities of success for each trial | ||
tau (float) : the temperature of the Gumble-Softmax distribution | ||
""" | ||
logits = torch.vstack((probs, 1 - probs)).T.log() | ||
gs_samples = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True) | ||
return gs_samples[:,0] | ||
|
||
|
||
def forward(self, params): | ||
""" | ||
Runs the model forward | ||
Arguments: | ||
params (torch.Tensor) : a tensor of shape (3,) containing the fraction of infected, beta, and gamma | ||
""" | ||
# Initialize the parameters | ||
initial_infected = params[0] | ||
beta = params[1] | ||
gamma = params[2] | ||
n_agents = self.graph.num_nodes | ||
# Initialize the state | ||
infected = torch.zeros(n_agents) | ||
susceptible = torch.ones(n_agents) | ||
recovered = torch.zeros(n_agents) | ||
# sample the initial infected nodes | ||
probs = initial_infected * torch.ones(n_agents) | ||
new_infected = self.sample_bernoulli_gs(probs) | ||
infected += new_infected | ||
susceptible -= new_infected | ||
|
||
infected_hist = infected.sum().reshape((1,)) | ||
recovered_hist = torch.zeros((1,)) | ||
|
||
# Run the model forward | ||
for _ in range(self.n_timesteps): | ||
# Get number of infected neighbors per node, return 0 if node is not susceptible. | ||
n_infected_neighbors = self.mp(self.graph.edge_index, infected, susceptible) | ||
# each contact has a beta chance of infecting a susceptible node | ||
prob_infection = 1 - (1 - beta) ** n_infected_neighbors | ||
# sample the infected nodes | ||
new_infected = self.sample_bernoulli_gs(prob_infection) | ||
# sample recoverd people | ||
prob_recovery = gamma * infected | ||
new_recovered = self.sample_bernoulli_gs(prob_recovery) | ||
# update the state of the agents | ||
infected = infected + new_infected - new_recovered | ||
susceptible -= new_infected | ||
recovered += new_recovered | ||
infected_hist = torch.hstack((infected_hist, infected.sum().reshape(1,))) | ||
recovered_hist = torch.hstack((recovered_hist, recovered.sum().reshape(1,))) | ||
|
||
return infected_hist, recovered_hist | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
networkx==3.0 | ||
nflows==0.14 | ||
torch==2.0 | ||
torch-geometric==2.3 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import torch | ||
import numpy as np | ||
import networkx | ||
|
||
from birds.models.sir import SIR | ||
|
||
|
||
class TestSIR: | ||
""" | ||
Tests the SIR implementation. | ||
""" | ||
|
||
def test__implementation_against_nnlib(self): | ||
""" | ||
Compares results with the nnlib implementation. | ||
https://github.com/GiulioRossetti/ndlib/blob/master/docs/reference/models/epidemics/SIR.rst | ||
```python | ||
import networkx as nx | ||
import ndlib.models.ModelConfig as mc | ||
import ndlib.models.epidemics as ep | ||
import numpy as np | ||
# Network topology | ||
g = nx.erdos_renyi_graph(1000, 0.01) | ||
n_timesteps = 10 | ||
beta = 0.05 | ||
gamma = 0.10 | ||
fraction_infected = 0.05 | ||
# Model selection | ||
model = ep.SIRModel(g) | ||
# Model Configuration | ||
cfg = mc.Configuration() | ||
cfg.add_model_parameter("fraction_infected", fraction_infected) | ||
cfg.add_model_parameter('beta', beta) | ||
cfg.add_model_parameter('gamma', gamma) | ||
model.set_initial_status(cfg) | ||
# Simulation execution | ||
status = np.array(list(model.status.values())) | ||
infected = (status == 1).sum() | ||
recovered = 0 | ||
susceptible = N - infected | ||
iterations = model.iteration_bunch(n_timesteps) | ||
susc = [] | ||
inf = [] | ||
rec = [] | ||
for iteration in iterations: | ||
status = np.array(list(iteration["status"].values())) | ||
new_infected = (status == 1).sum() | ||
new_recovered = (status == 2).sum() | ||
infected = infected + new_infected - new_recovered | ||
recovered += new_recovered | ||
susceptible -= new_infected | ||
susc.append(susceptible) | ||
inf.append(infected) | ||
rec.append(recovered) | ||
print(inf) | ||
print(rec) | ||
``` | ||
>>> [50, 100, 114, 135, 174, 217, 260, 304, 329, 377, 407] | ||
>>> [0, 0, 3, 6, 11, 25, 41, 62, 93, 122, 163] | ||
""" | ||
N = 1000 | ||
beta = 0.05 | ||
gamma = 0.10 | ||
fraction_infected = 0.05 | ||
n_timesteps = 10 # we do one more... | ||
graph = networkx.erdos_renyi_graph(N, 0.01) | ||
model = SIR(graph=graph, n_timesteps=n_timesteps) | ||
infected, recovered = model( | ||
torch.tensor([fraction_infected, beta, gamma]) | ||
) # fraction infected, beta, and gamma | ||
exp_infected = torch.tensor( | ||
[50, 100, 114, 135, 174, 217, 260, 304, 329, 377, 407], dtype=torch.float | ||
) | ||
exp_recovered = torch.tensor( | ||
[0, 0, 3, 6, 11, 25, 41, 62, 93, 122, 163], dtype=torch.float | ||
) | ||
# check initial infected fraction | ||
assert np.isclose(infected[0], 0.05 * N, rtol=0.3, atol=2) | ||
# values from nndlib run | ||
assert torch.allclose( | ||
infected[:-1], exp_infected[1:], rtol=1.0 | ||
) # check is close within a factor of 2... seed is tricky here | ||
# they do recovery differently... | ||
assert torch.allclose(recovered[:-1], exp_recovered[1:], rtol=1.0, atol=5) | ||
|
||
def test__gradient_propagates(self): | ||
""" | ||
Checks that the gradient propagates through the model. | ||
""" | ||
N = 1000 | ||
beta = 0.05 | ||
gamma = 0.10 | ||
fraction_infected = 0.05 | ||
n_timesteps = 10 | ||
graph = networkx.erdos_renyi_graph(N, 0.01) | ||
model = SIR(graph=graph, n_timesteps=n_timesteps) | ||
probs = torch.tensor([fraction_infected, beta, gamma], requires_grad=True) | ||
infected, _ = model(probs) | ||
infected.sum().backward() | ||
assert probs.grad is not None |