# **SIR Agent-Based Model with Behavioral Heterogeneity**

This notebook demonstrates how to implement an agent-based SIR (Susceptible–Infected–Recovered) model using the **`sagesim`** library.

The classical **SIR model** provides a foundational framework in epidemiology for modeling the spread of infectious diseases. By extending it to an **agent-based model (ABM)**, we capture individual-level variation and interaction patterns, allowing us to simulate more realistic and heterogeneous dynamics of disease transmission.

In this implementation:

- Each **agent** represents an individual with a unique behavioral profile that influences their susceptibility to infection.
- Agents occupy one of three states:
  - **Susceptible**: Healthy, but at risk of infection.
  - **Infected**: Currently infected and capable of spreading the disease.
  - **Recovered**: No longer infectious and assumed immune.

- **Transmission** occurs through interactions between *connected agents* (i.e., agents with network edges between them). Importantly, the probability of infection upon contact is **not uniform**. It depends on each agent's **preventative behaviors**, such as:
  - Hygiene (e.g., handwashing, mask usage)
  - Social distancing
  - Vaccination status

These behaviors are encoded as 100-dimensional vectors with values in \([0, 1]\), where higher values indicate stronger protection.

##### **Transmission Behaviors**

The infection decision follows these key steps:

1. **Interaction Safety Evaluation**  
   For each pairwise combination of the agent’s and its neighbor’s *preventative_measures*, a score is computed by multiplying the respective values. These products are summed to produce an aggregate measure of the *absolute safety of interaction*.

2. **Normalization**  
   The summed safety score is normalized by dividing by the square of the length of the preventative measures vector. This results in a *normalized safety score* in the range [0, 1], where higher values indicate stronger mutual preventative behavior.

3. **Infection Rule**  
   If a neighbor is **infected** (`neighbor_state == 2`), the agent has a chance to become infected depending on the infection probability `p_infection`, reduced by the *normalized safety score*. Specifically, infection occurs if a sampled random number is less than `p_infection × (1 - normalized safety score)`.

4. **State Update**  
   If the infection condition is met, the agent's state is updated from **susceptible** to **infected**.


##### **Workflow Overview**

In the remainder of this notebook, we will build a behavior-aware SIR agent-based model using the `sagesim` library. The following steps mirror the procedure outlined in the project README:

1. **Define the `SIRModel` class**  
   Subclass the `Model` class from `sagesim` to implement the SIR logic, including state transitions and transmission dynamics influenced by agents’ protective behaviors.

2. **Instantiate the model on a contact network**  
   Construct a *Watts–Strogatz* small-world network to represent social interactions. Each node corresponds to an agent, and edges define potential transmission pathways. We initialize the `SIRModel` with this network and a population of behaviorally heterogeneous agents.

3. **Run the simulation and analyze results**  
   Simulate disease spread over time and analyze how individual behaviors shape the epidemic curve and overall dynamics.


## Define the `SIRModel` class ##

### 1. **Define and Register `SIRBreed`**

As explained earlier, we begin by defining a custom breed by subclassing the `Breed` class from the `sagesim` library. In the SIR model, we use a single breed to represent the general population. This breed is implemented as the `SIRBreed` class.

- Each agent in this breed has two primary properties:

    - **`state`**: A categorical variable indicating the agent's infection status. We encode the three SIR states as:
        - `1` — Susceptible  
        - `2` — Infected  
        - `3` — Recovered

    - **`preventative_measures`**: A list of 100 floating-point values that capture the agent’s individual behavior traits—such as hygiene practices, social distancing adherence, or vaccination status—that influence their likelihood of infection.

- Recall we register step functions with specified priority levels to define how agents of a breed behave at each simulation step. In the case of the SIR model, we use a single step function that governs how an agent's state evolves over time—for instance, determining whether a susceptible agent becomes infected or an infected agent recovers.

In [1]:
from enum import Enum
from time import time
from random import random, sample
import networkx as nx
from statistics import mean
from mpi4py import MPI



# import the Breed class from sagesim
from sagesim.breed import Breed

# Define the SIRState enumeration for agent states
class SIRState(Enum):
    SUSCEPTIBLE = 1
    INFECTED = 2
    RECOVERED = 3

# Define the step function to be registered for SIRBreed
def step_func(agent_ids, agent_index, globals, breeds, locations, state_adt, preventative_measures_adt):
    """
    At each simulation step, this function evaluates a subset of agents—either all agents in a serial run or a partition assigned to
    a specific rank in parallel processing—and determines whether an agent's state should change based on interactions with its neighbors
    and their respective preventative behaviors.

    Parameters:
    ----------
    agent_ids : list[int]
        The adt that contains the IDs of all agents assigned to the current rank, and their neighbors.
    agent_index : int
        Index of the agent being evaluated in the agent_ids list. 
    globals : list
        Global parameters; 
        the zero-th global parameter is by default the simulation tick, 
        the first item will be our infection probability $p$.
        the second item is the susceptibility reduction strength $alpha$.
        the third item is the infectiousness reduction strength $beta$.
    breeds : list
        List of breed objects (unused here as we only have one type of breed, but must passed for interface compatibility).
    locations : list[list[int]]
        Adjacency list specifying neighbors for each agent.
    state_adt : list[int]
        List of current state of each agent.
    preventative_measures_adt : list[list[float]]
        List of vectors representing each agent’s preventative behaviors. 
    Returns:
    -------
    None
        The function updates the `states` list in-place if an agent becomes infected.
    """
    # Get the list of neighboring agent IDs for the current agent based on network topology
    neighbor_ids = locations[agent_index]

    # Draw a random float in [0, 1) for stochastic decision-making
    rand = random()  # can replace with step_func_helper_get_random_float(rng_states, id)

    # Retrieve the global infection probability defined in the model
    p_infection = globals[1]

    # Get the preventative measures vector for the current agent
    agent_preventative_measures = preventative_measures_adt[agent_index]

    # Loop through each neighbor ID
    for i in range(len(neighbor_ids)):

        # Initialize neighbor_index to invalid value
        neighbor_index = -1

        i = 0
        while i < len(agent_ids) and agent_ids[i] != neighbor_ids[0]:
            i += 1
        if i < len(agent_ids):
            neighbor_index = i

            # Retrieve the state of the neighbor (e.g., susceptible, infected, recovered)
            neighbor_state = int(state_adt[neighbor_index])

            # Get the preventative measures vector of the neighbor
            neighbor_preventative_measures = preventative_measures_adt[neighbor_index]

            # Initialize cumulative safety score for the interaction
            abs_safety_of_interaction = 0.0

            # Calculate total safety of interaction based on pairwise product of measures
            for n in range(len(agent_preventative_measures)):
                for m in range(len(neighbor_preventative_measures)):
                    abs_safety_of_interaction += (
                        agent_preventative_measures[n] * neighbor_preventative_measures[m]
                    )

            # Normalize the safety score to be in [0, 1]
            normalized_safety_of_interaction = abs_safety_of_interaction / (
                len(agent_preventative_measures) ** 2
            )

            # If neighbor is infected and the infection condition passes, update agent’s state
            if neighbor_state == 2 and rand < p_infection * (
                1 - normalized_safety_of_interaction
            ):
                state_adt[agent_index] = 2  # Agent becomes infected


Now we are ready to define the `SIRBreed` class, which inherits from the base `Breed` class.

- Each breed must be given a unique name. This name is used internally by `sagesim` to identify and manage the breed.
- When calling `register_property`, you can optionally specify an initial value to assign to all agents of this breed. If not provided, the default is `nan`. You may also override this value when creating individual agents.
- The breed's behavior is defined by registering one or more step functions using `register_step_function()`. You can specify a priority for each function, with the default being `0`.

In [2]:

class SIRBreed(Breed):
    """
    SIRBreed class the SIR model.
    Inherits from the Breed class in the sagesim library.
    """

    def __init__(self) -> None:
        name = "SIR"
        super().__init__(name) 
        # Register properties for the breed
        self.register_property("state", SIRState.SUSCEPTIBLE.value) 
        self.register_property("preventative_measures", [random() for _ in range(100)])
        # Register the step function
        self.register_step_func(step_func)


A susceptible agent checks whether any of its infected neighbors has successfully infected it. If so, it updates only its own state (e.g., from susceptible to infected). Since agents do not modify the state of others, no data conflicts arise across ranks, and therefore, no reduce function is needed.

We are now ready to initialize the `SIRModel`:
 - register breed
 - register globle properities
 - make class methods `create_agent()` and `connect_agent()` to create and create neighborhood of agents

In [3]:
from sagesim.model import Model
from sagesim.space import NetworkSpace # hopefully, we can avoid this import


class SIRModel(Model):
    """
    SIRModel class for the SIR model.
    Inherits from the Model class in the sagesim library.
    """

    def __init__(self, p_infection=1.0) -> None:
        space = NetworkSpace()
        super().__init__(space)
        self._sir_breed = SIRBreed()

        # Register the breed
        self.register_breed(breed=self._sir_breed)

        # register user-defined global properties
        self.register_global_property("p_infection", p_infection)

    # create_agent method takes user-defined properties, that is, the state and preventative_measures, to create an agent
    def create_agent(self, state, preventative_measures):
        agent_id = self.create_agent_of_breed(
            self._sir_breed, state=state, preventative_measures=preventative_measures
        )
        self.get_space().add_agent(agent_id)
        return agent_id

    def connect_agents(self, agent_0, agent_1):
        self.get_space().connect_agents(agent_0, agent_1)


## **Instantiate the SIR Model**

1. **Generate the Contact Network**  
   Create a *Watts–Strogatz* small-world graph to approximate real-world contact patterns. This topology offers high clustering and short average path lengths, making it ideal for modeling disease transmission. Each node represents an agent, and each edge represents a potential transmission link.

2. **Create and Add Agents**  
   Instantiate the `SIRModel`, then call `create_agent()` for each node to initialize agents in the **susceptible** state with their corresponding `preventative_measures` vector.

3. **Connect Agents**  
   Use the model’s `connect_agents(agent_a, agent_b)` method to add edges between agents according to the network structure.

4. **Initialize Infections**  
   Randomly select `num_infected` agents and set their state to **infected** to seed the simulation.


In [4]:
num_agents = 1000
num_init_connections = 20
rewiring_prob = 0.1

num_infected = 10

# Generate the Contact Network
network = nx.watts_strogatz_graph(num_agents, num_init_connections, rewiring_prob)

# Instantiate the SIR Model
model = SIRModel()
model.setup(use_gpu=True)  # Enables GPU acceleration if available

# Create agents 
for n in network.nodes:
    preventative_measures = [random() for _ in range(100)] 
    model.create_agent(SIRState.SUSCEPTIBLE.value, preventative_measures)

# Connect agents in the network
for edge in network.edges:
    model.connect_agents(edge[0], edge[1])

# Infect a random sample of agents  
for n in sample(sorted(network.nodes), num_infected):
    model.set_agent_property_value(n, "state", SIRState.INFECTED.value)

## **Run the Simulation and Analyze Results**

To run the simulation, we use `model.simulate(ticks=10, sync_workers_every_n_ticks=1)` to simulate 10 time steps with synchronization after each tick.

- **Single-Rank Execution**:  
  We can first run the simulation using a single MPI rank by executing the cell directly. We also measure the execution time to establish a performance baseline.

- **Multi-Rank Execution**:  
  To evaluate performance with multiple ranks, the simulation must be executed in a parallel environment. Since this is not supported within the current notebook, we can run the same code in a separate script using the following command: `mpiexec -n 4 python tutorial_run.py`
  


In [5]:
# # MPI environment setup
comm = MPI.COMM_WORLD
num_workers = comm.Get_size()
worker = comm.Get_rank()

# Run the simulation with 1 rank, and measure the time taken
simulate_start = time()
model.simulate(ticks = 10, sync_workers_every_n_ticks=1)
simulate_end = time()
simulate_duration = simulate_end - simulate_start
print(f"Simulation with 1 rank took {simulate_duration:.2f} seconds.")


Simulation with 1 rank took 2.04 seconds.


In [6]:
# Get the state of each agents after 10 simulation runs
result = [
    SIRState(model.get_agent_property_value(agent_id, property_name="state"))
    for agent_id in range(num_agents)
    if model.get_agent_property_value(agent_id, property_name="state") is not None
]

# count the number of infected agents
num_infected = sum(1 for state in result if state == SIRState.INFECTED)
num_recovered = sum(1 for state in result if state == SIRState.RECOVERED)
num_susceptible = sum(1 for state in result if state == SIRState.SUSCEPTIBLE)
print(f"Number of infected agents: {num_infected}")
print(f"Number of recovered agents: {num_recovered}")
print(f"Number of susceptible agents: {num_susceptible}")

Number of infected agents: 10
Number of recovered agents: 0
Number of susceptible agents: 990


In [8]:
!mpirun -n 4 python tutorial_run.py

Simulation took 1.63 seconds.
Number of infected agents: 11
Number of recovered agents: 0
Number of susceptible agents: 989
