From f69bccf946bff69b88b34737514fb2e117f1a932 Mon Sep 17 00:00:00 2001 From: Arnau Quera-Bofarull Date: Fri, 12 May 2023 18:38:19 +0100 Subject: [PATCH] added sir model with tests comparing to nndlib --- birds/models/sir.py | 82 ++++++++++++++++++++++++++++++ requirements.txt | 1 + test/models/test_sir.py | 109 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 192 insertions(+) create mode 100644 birds/models/sir.py create mode 100644 test/models/test_sir.py diff --git a/birds/models/sir.py b/birds/models/sir.py new file mode 100644 index 0000000..5f52249 --- /dev/null +++ b/birds/models/sir.py @@ -0,0 +1,82 @@ +import torch +import torch_geometric + +class SIRMessagePassing(torch_geometric.nn.conv.MessagePassing): + def forward(self, edge_index, infected, susceptible): + return self.propagate(edge_index, x=infected, y=susceptible) + def message(self, x_j, y_i): + return x_j * y_i + +class SIR(torch.nn.Module): + def __init__(self, graph, n_timesteps): + """ + Implements a differentiable SIR model on a graph. + + Arguments: + graph (networkx.Graph) : a networkx graph + n_timesteps (int) : the number of timesteps to run the model for + """ + super().__init__() + self.n_timesteps = n_timesteps + # convert graph from networkx to pytorch geometric + self.graph = torch_geometric.utils.convert.from_networkx(graph) + self.mp = SIRMessagePassing(aggr='add', node_dim=-1) + + def sample_bernoulli_gs(self, probs, tau=0.1): + """ + Samples from a Bernoulli distribution in a diferentiable way using Gumble-Softmax + + Arguments: + probs (torch.Tensor) : a tensor of shape (n,) containing the probabilities of success for each trial + tau (float) : the temperature of the Gumble-Softmax distribution + """ + logits = torch.vstack((probs, 1 - probs)).T.log() + gs_samples = torch.nn.functional.gumbel_softmax(logits, tau=tau, hard=True) + return gs_samples[:,0] + + + def forward(self, params): + """ + Runs the model forward + + Arguments: + params (torch.Tensor) : a tensor of shape (3,) containing the fraction of infected, beta, and gamma + """ + # Initialize the parameters + initial_infected = params[0] + beta = params[1] + gamma = params[2] + n_agents = self.graph.num_nodes + # Initialize the state + infected = torch.zeros(n_agents) + susceptible = torch.ones(n_agents) + recovered = torch.zeros(n_agents) + # sample the initial infected nodes + probs = initial_infected * torch.ones(n_agents) + new_infected = self.sample_bernoulli_gs(probs) + infected += new_infected + susceptible -= new_infected + + infected_hist = infected.sum().reshape((1,)) + recovered_hist = torch.zeros((1,)) + + # Run the model forward + for _ in range(self.n_timesteps): + # Get number of infected neighbors per node, return 0 if node is not susceptible. + n_infected_neighbors = self.mp(self.graph.edge_index, infected, susceptible) + # each contact has a beta chance of infecting a susceptible node + prob_infection = 1 - (1 - beta) ** n_infected_neighbors + # sample the infected nodes + new_infected = self.sample_bernoulli_gs(prob_infection) + # sample recoverd people + prob_recovery = gamma * infected + new_recovered = self.sample_bernoulli_gs(prob_recovery) + # update the state of the agents + infected = infected + new_infected - new_recovered + susceptible -= new_infected + recovered += new_recovered + infected_hist = torch.hstack((infected_hist, infected.sum().reshape(1,))) + recovered_hist = torch.hstack((recovered_hist, recovered.sum().reshape(1,))) + + return infected_hist, recovered_hist + diff --git a/requirements.txt b/requirements.txt index 42c878c..91a464e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +networkx==3.0 nflows==0.14 torch==2.0 torch-geometric==2.3 diff --git a/test/models/test_sir.py b/test/models/test_sir.py new file mode 100644 index 0000000..46dc4d3 --- /dev/null +++ b/test/models/test_sir.py @@ -0,0 +1,109 @@ +import torch +import numpy as np +import networkx + +from birds.models.sir import SIR + + +class TestSIR: + """ + Tests the SIR implementation. + """ + + def test__implementation_against_nnlib(self): + """ + Compares results with the nnlib implementation. + https://github.com/GiulioRossetti/ndlib/blob/master/docs/reference/models/epidemics/SIR.rst + ```python + import networkx as nx + import ndlib.models.ModelConfig as mc + import ndlib.models.epidemics as ep + import numpy as np + + # Network topology + g = nx.erdos_renyi_graph(1000, 0.01) + n_timesteps = 10 + beta = 0.05 + gamma = 0.10 + fraction_infected = 0.05 + + # Model selection + model = ep.SIRModel(g) + + # Model Configuration + cfg = mc.Configuration() + cfg.add_model_parameter("fraction_infected", fraction_infected) + cfg.add_model_parameter('beta', beta) + cfg.add_model_parameter('gamma', gamma) + model.set_initial_status(cfg) + + # Simulation execution + + status = np.array(list(model.status.values())) + infected = (status == 1).sum() + recovered = 0 + susceptible = N - infected + + iterations = model.iteration_bunch(n_timesteps) + + susc = [] + inf = [] + rec = [] + for iteration in iterations: + status = np.array(list(iteration["status"].values())) + new_infected = (status == 1).sum() + new_recovered = (status == 2).sum() + infected = infected + new_infected - new_recovered + recovered += new_recovered + susceptible -= new_infected + susc.append(susceptible) + inf.append(infected) + rec.append(recovered) + + print(inf) + print(rec) + ``` + >>> [50, 100, 114, 135, 174, 217, 260, 304, 329, 377, 407] + >>> [0, 0, 3, 6, 11, 25, 41, 62, 93, 122, 163] + + """ + N = 1000 + beta = 0.05 + gamma = 0.10 + fraction_infected = 0.05 + n_timesteps = 10 # we do one more... + graph = networkx.erdos_renyi_graph(N, 0.01) + model = SIR(graph=graph, n_timesteps=n_timesteps) + infected, recovered = model( + torch.tensor([fraction_infected, beta, gamma]) + ) # fraction infected, beta, and gamma + exp_infected = torch.tensor( + [50, 100, 114, 135, 174, 217, 260, 304, 329, 377, 407], dtype=torch.float + ) + exp_recovered = torch.tensor( + [0, 0, 3, 6, 11, 25, 41, 62, 93, 122, 163], dtype=torch.float + ) + # check initial infected fraction + assert np.isclose(infected[0], 0.05 * N, rtol=0.3, atol=2) + # values from nndlib run + assert torch.allclose( + infected[:-1], exp_infected[1:], rtol=1.0 + ) # check is close within a factor of 2... seed is tricky here + # they do recovery differently... + assert torch.allclose(recovered[:-1], exp_recovered[1:], rtol=1.0, atol=5) + + def test__gradient_propagates(self): + """ + Checks that the gradient propagates through the model. + """ + N = 1000 + beta = 0.05 + gamma = 0.10 + fraction_infected = 0.05 + n_timesteps = 10 + graph = networkx.erdos_renyi_graph(N, 0.01) + model = SIR(graph=graph, n_timesteps=n_timesteps) + probs = torch.tensor([fraction_infected, beta, gamma], requires_grad=True) + infected, _ = model(probs) + infected.sum().backward() + assert probs.grad is not None