# **SIR Agent-Based Model with Behavioral Heterogeneity**

This notebook demonstrates how to implement an agent-based SIR (Susceptible–Infected–Recovered) model using the **`sagesim`** library.

The classical **SIR model** provides a foundational framework in epidemiology for modeling the spread of infectious diseases. By extending it to an **agent-based model (ABM)**, we capture individual-level variation and interaction patterns, allowing us to simulate more realistic and heterogeneous dynamics of disease transmission.

In this implementation:

- Each **agent** represents an individual with a unique behavioral profile that influences their susceptibility to infection.
- Agents occupy one of three states:
  - **Susceptible**: Healthy, but at risk of infection.
  - **Infected**: Currently infected and capable of spreading the disease.
  - **Recovered**: No longer infectious and assumed immune.

- **Transmission** occurs through interactions between *connected agents* (i.e., agents with network edges between them). Importantly, the probability of infection upon contact is **not uniform**. It depends on each agent's **preventative behaviors**, such as:
  - Hygiene (e.g., handwashing, mask usage)
  - Social distancing
  - Vaccination status

These behaviors are encoded as 100-dimensional vectors with values in $[0, 1]$, where higher values indicate stronger protection.

##### **Transmission Probability Based on Protective Behaviors**

To model the influence of individual behaviors on transmission, we define the **effective transmission probability** between an infected agent `b` and a susceptible agent `a` as:

$$
p_{\text{eff}}(a, b) = p \cdot (1 - s_a)^\alpha \cdot (1 - s_b)^\beta
$$

where:
- $p$ is the global base infection probability,
- $s_a = \text{mean}(a)$ is the average protection level of the susceptible agent,
- $s_b = \text{mean}(b)$ is the average protection level of the infected agent,
- $\alpha > 0$ and $\beta > 0$ control how strongly protective behaviors reduce susceptibility and infectiousness, respectively.

This formulation ensures that greater adherence to preventative measures by **either** agent leads to a lower likelihood of transmission.
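As a quick numeric check of the formula, the sketch below computes $p_{\text{eff}}$ from two behavior vectors. `effective_transmission_probability` is an illustrative helper, not part of `sagesim`; it mirrors the calculation that the step function performs later in this notebook.

```python
from statistics import mean


def effective_transmission_probability(p, pm_susceptible, pm_infected, alpha=1.0, beta=1.0):
    """Return p * (1 - s_a)**alpha * (1 - s_b)**beta for two behavior vectors."""
    s_a = mean(pm_susceptible)  # average protection of the susceptible agent
    s_b = mean(pm_infected)     # average protection of the infected agent
    return p * (1 - s_a) ** alpha * (1 - s_b) ** beta


# Two agents whose 100-dimensional behavior vectors both average 0.5
a = [0.2, 0.8] * 50
b = [0.5] * 100
print(effective_transmission_probability(1.0, a, b))             # 0.25
print(effective_transmission_probability(1.0, a, b, alpha=2.0))  # 0.125
```

Raising $\alpha$ from 1 to 2 halves the probability here because the susceptible agent's factor $(1 - s_a)$ is 0.5 and is now applied twice.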

##### **Workflow Overview**

In the remainder of this notebook, we will build a behavior-aware SIR agent-based model using the `sagesim` library. The following steps mirror the procedure outlined in the project README:

1. **Define the `SIRModel` class**  
   Subclass the `Model` class from `sagesim` to implement the SIR logic, including state transitions and transmission dynamics influenced by agents’ protective behaviors.

2. **Instantiate the model on a contact network**  
   Construct a *Watts–Strogatz* small-world network to represent social interactions. Each node corresponds to an agent, and edges define potential transmission pathways. We initialize the `SIRModel` with this network and a population of behaviorally heterogeneous agents.

3. **Run the simulation and analyze results**  
   Simulate disease spread over time and analyze how individual behaviors shape the epidemic curve and overall dynamics.


## Define the `SIRModel` class ##

### 1. **Define and Register `SIRBreed`**

We begin by defining a custom breed by subclassing the `Breed` class from the `sagesim` library. In the SIR model, we use a single breed to represent the general population. This breed is implemented as the `SIRBreed` class.

- Each agent in this breed has two primary properties:

    - **`state`**: A categorical variable indicating the agent's infection status. We encode the three SIR states as:
        - `1` — Susceptible  
        - `2` — Infected  
        - `3` — Recovered

    - **`preventative_measures`**: A list of 100 floating-point values that capture the agent’s individual behavior traits—such as hygiene practices, social distancing adherence, or vaccination status—that influence their likelihood of infection.

- Recall that step functions, registered with specified priority levels, define how agents of a breed behave at each simulation step. The SIR model uses a single step function that governs how an agent's state evolves over time. In this version it handles only the susceptible-to-infected transition: a susceptible agent may become infected by an infected neighbor. The step function below does not implement recovery, so infected agents stay infected.

In [None]:
from enum import Enum
from time import time
from random import random, sample
import networkx as nx
from statistics import mean


# import the Breed class from sagesim
from sagesim.breed import Breed

# Define the SIRState enumeration for agent states
class SIRState(Enum):
    SUSCEPTIBLE = 1
    INFECTED = 2
    RECOVERED = 3

# Define the step function to be registered for SIRBreed
def step_func(agent_ids, agent_index, globals, breeds, locations, state_adt, preventative_measures_adt):
    """
    At each simulation step, this function evaluates a subset of agents—either all agents in a serial run or a partition assigned to
    a specific rank in parallel processing—and determines whether an agent's state should change based on interactions with its neighbors
    and their respective preventative behaviors.

    Parameters:
    ----------
    agent_ids : list[int]
        The adt that contains the IDs of all agents assigned to the current rank, and their neighbors.
    agent_index : int
        Index of the agent being evaluated in the agent_ids list. 
    globals : list
        Global parameters; 
        the zero-th global parameter is by default the simulation tick, 
        the first item will be our infection probability $p$.
        the second item is the susceptibility reduction strength $alpha$.
        the third item is the infectiousness reduction strength $beta$.
    breeds : list
        List of breed objects (unused here since there is only one breed, but required for interface compatibility).
    locations : list[list[int]]
        Adjacency list specifying neighbors for each agent.
    state_adt : list[int]
        List of current state of each agent.
    preventative_measures_adt : list[list[float]]
        List of vectors representing each agent’s preventative behaviors. 
    Returns:
    -------
    None
        The function updates `state_adt` in place if the agent becomes infected.
    """

    # Retrieve this agent’s neighbors
    neighbor_ids = locations[agent_index]

    # Skip step if the agent is not susceptible, i.e., if it is already infected or recovered.
    if int(state_adt[agent_index]) != 1:
        return

    # Get global infection probability
    p = globals[1]
    # Get global susceptibility reduction strength
    alpha = globals[2]
    # Get global infectiousness reduction strength
    beta = globals[3]

    # Preventative measures of the current (susceptible) agent
    agent_pm = preventative_measures_adt[agent_index]
    agent_pm_mean = mean(agent_pm)

    # Draw one random number; it is reused for every infected neighbor below
    rand_val = random()

    # Loop over all neighbors
    for neighbor_id in neighbor_ids:
        # Find index of neighbor in agent_ids list
        try:
            neighbor_index = agent_ids.index(neighbor_id)
        except ValueError:
            continue  # skip if neighbor ID not found (should not happen)

        # Check if the neighbor is infected
        if int(state_adt[neighbor_index]) == 2:
            neighbor_pm = preventative_measures_adt[neighbor_index]
            neighbor_pm_mean = mean(neighbor_pm)
            p_eff = p * (1 - agent_pm_mean)**alpha * (1 - neighbor_pm_mean)**beta
            # Infect if the random draw falls below the effective transmission probability
            if rand_val < p_eff:
                state_adt[agent_index] = 2  # Agent becomes infected
                return  # No need to continue checking other neighbors
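
Note that `step_func` draws a single random number per susceptible agent and reuses it for every infected neighbor. Under that rule the per-step infection probability is the largest $p_{\text{eff}}$ among infected neighbors, not the $1 - \prod_i (1 - p_{\text{eff},i})$ one would get from an independent draw per contact. The pure-Python sketch below compares the two; the helper names are illustrative and are not `sagesim` functions.

```python
from math import prod
from random import Random


def p_infection_shared_draw(p_effs):
    # One uniform draw compared against every infected neighbor's p_eff:
    # infection happens iff the draw falls below the largest p_eff.
    return max(p_effs, default=0.0)


def p_infection_independent_draws(p_effs):
    # One independent draw per infected neighbor: the agent escapes only if every contact fails.
    return 1.0 - prod(1.0 - p for p in p_effs)


def monte_carlo_shared_draw(p_effs, trials=100_000, seed=0):
    # Empirically replicate the loop in step_func: a single draw reused for all neighbors.
    rng = Random(seed)
    hits = 0
    for _ in range(trials):
        r = rng.random()
        if any(r < p for p in p_effs):
            hits += 1
    return hits / trials


p_effs = [0.25, 0.25, 0.25]  # three infected neighbors
print(p_infection_shared_draw(p_effs))        # 0.25
print(p_infection_independent_draws(p_effs))  # 0.578125
print(monte_carlo_shared_draw(p_effs))        # close to 0.25
```

Which rule is appropriate is a modeling choice; the shared draw makes additional infected neighbors matter only when they raise the maximum $p_{\text{eff}}$.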


Now we are ready to define the `SIRBreed` class, which inherits from the base `Breed` class.

- Each breed must be given a unique name. This name is used internally by `sagesim` to identify and manage the breed.
- When calling `register_property`, you can optionally specify an initial value to assign to all agents of this breed. If not provided, the default is `nan`. You may also override this value when creating individual agents.
- The breed's behavior is defined by registering one or more step functions using `register_step_func()`. You can specify a priority for each function, with the default being `0`.

In [None]:

class SIRBreed(Breed):
    """
    Breed representing the general population in the SIR model.
    Inherits from the Breed class in the sagesim library.
    """

    def __init__(self) -> None:
        name = "SIR"
        super().__init__(name) 
        # Register properties for the breed
        self.register_property("state", SIRState.SUSCEPTIBLE.value) 
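        # Default behavior vector, drawn once at registration; agents created with
        # explicit preventative_measures override it
        self.register_property("preventative_measures", [random() for _ in range(100)])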
        # Register the step function
        self.register_step_func(step_func)


As each agent 

### 2. **Create and Connect Agents**


With `SIRBreed` defined, we're now ready to define the `SIRModel` class. Besides registering the breed, the model needs methods that create agents and establish connections between them.

- **Creating agents**: Use the model method `create_agent_of_breed()`, which takes the breed object along with user-defined breed properties (such as `state` and `preventative_measures`). It returns the unique ID of the newly created agent.

- **Connecting agents**: Use `self.get_space().connect_agents()` to connect two agents by their IDs. This establishes a neighbor relationship between them in the simulation space.

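This model also registers three **global properties**: the base infection probability $p$ (`p_infection`), which is the baseline probability that a susceptible agent becomes infected when in contact with an infected neighbor, and the exponents $\alpha$ (`alpha`) and $\beta$ (`beta`). The final infection probability is adjusted based on the `preventative_measures` of both the agent and its neighbor, following the $p_{\text{eff}}$ formula. Because `step_func` reads these values by position, they must be registered in the order $p$, $\alpha$, $\beta$, so that they occupy `globals[1]` through `globals[3]` after the simulation tick at `globals[0]`.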
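The positional coupling between registration order and `globals` indices is easy to break when adding a new parameter. A minimal host-side sketch of one way to make it explicit, assuming the layout described above (tick first, then the user globals in registration order). `SIRGlobals`, `unpack_globals`, and the index constants are hypothetical helpers, not `sagesim` API, and a named tuple may not be usable inside a compiled or GPU step function.

```python
from typing import NamedTuple, Sequence

# Hypothetical named indices: they assume globals[0] is the tick and the
# user-defined globals follow in registration order (p_infection, alpha, beta).
TICK_IDX, P_INFECTION_IDX, ALPHA_IDX, BETA_IDX = 0, 1, 2, 3


class SIRGlobals(NamedTuple):
    tick: int
    p_infection: float
    alpha: float
    beta: float


def unpack_globals(globals_: Sequence[float]) -> SIRGlobals:
    """Read the SIR parameters from a flat globals list by position."""
    return SIRGlobals(
        tick=globals_[TICK_IDX],
        p_infection=globals_[P_INFECTION_IDX],
        alpha=globals_[ALPHA_IDX],
        beta=globals_[BETA_IDX],
    )


print(unpack_globals([0, 1.0, 1.0, 1.0]))
# SIRGlobals(tick=0, p_infection=1.0, alpha=1.0, beta=1.0)
```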

With these components in place, we’re ready to define the `SIRModel` class tailored to simulate the dynamics of the SIR agent-based model.

In [None]:
from sagesim.model import Model
from sagesim.space import NetworkSpace


class SIRModel(Model):
    """
    SIRModel class for the SIR model.
    Inherits from the Model class in the sagesim library.
    """

    def __init__(self, p_infection=1.0, alpha=1.0, beta=1.0) -> None:
        space = NetworkSpace()
        super().__init__(space)
        self._sir_breed = SIRBreed()

        # Register the breed
        self.register_breed(breed=self._sir_breed)

        # Register user-defined global properties. Order matters: step_func reads
        # them as globals[1], globals[2], globals[3] (globals[0] is the tick).
        self.register_global_property("p_infection", p_infection)
        self.register_global_property("alpha", alpha)
        self.register_global_property("beta", beta)

    # create_agent method takes user-defined properties, that is, the state and preventative_measures, to create an agent
    def create_agent(self, state, preventative_measures):
        agent_id = self.create_agent_of_breed(
            self._sir_breed, state=state, preventative_measures=preventative_measures
        )
        self.get_space().add_agent(agent_id)
        return agent_id

    def connect_agents(self, agent_0, agent_1):
        self.get_space().connect_agents(agent_0, agent_1)


## **Instantiate the SIR Model**  

For this notebook, we use a *Watts–Strogatz* small-world network to mimic real-world contact networks. This graph model captures key features of such networks, namely high clustering and short average path lengths, which makes it well suited to modeling infectious disease spread. Each node in the graph corresponds to an agent, and each edge defines a neighbor relationship along which transmission can occur.
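The small-world property can be checked directly with `networkx`. The sketch compares a ring lattice (rewiring probability 0), a Watts–Strogatz graph with the size and rewiring probability 0.2 used in this notebook, and a random graph with the same number of edges. One would expect the small-world graph to keep much of the lattice's clustering while its average path length drops toward that of the random graph. The seeds are arbitrary, and this check is separate from the simulation itself.

```python
import networkx as nx

n, k, rewire_p = 1000, 6, 0.2  # same sizes and rewiring probability as this notebook


def summarize(g):
    """Return (average clustering, average shortest path length on the largest component)."""
    # A rewired or random graph is not guaranteed to be connected, so path
    # lengths are measured on the largest connected component.
    giant = g.subgraph(max(nx.connected_components(g), key=len))
    return nx.average_clustering(g), nx.average_shortest_path_length(giant)


lattice = nx.watts_strogatz_graph(n, k, 0.0)  # ring lattice, no rewiring
small_world = nx.watts_strogatz_graph(n, k, rewire_p, seed=42)
random_graph = nx.gnm_random_graph(n, small_world.number_of_edges(), seed=42)

results = {}
for name, g in [("lattice", lattice), ("small world", small_world), ("random", random_graph)]:
    results[name] = summarize(g)
    clustering, path_length = results[name]
    print(f"{name:12s} clustering={clustering:.3f} avg path length={path_length:.2f}")
```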

Once the network is created, we create one agent per node and connect the agents along the network's edges. A randomly chosen subset of agents is then marked as initially infected.

The helper `generate_small_world_of_agents(model, num_agents, num_init_connections, num_infected)` creates the SIR agents on a small-world network topology.

Parameters:
- `model` (`SIRModel`): An instance of the SIR model.
- `num_agents` (int): Total number of agents.
- `num_init_connections` (int): Each agent is initially connected to this many nearest neighbors, before rewiring.
- `num_infected` (int): Number of initially infected agents.

Returns:
- `SIRModel`: The initialized model with agents and connections.

In [None]:
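from mpi4py import MPI

# MPI environment setup (rank 0 reports results later in the notebook)
comm = MPI.COMM_WORLD
num_workers = comm.Get_size()
worker = comm.Get_rank()


def generate_small_world_network(n, k, p):
    """
    Generate a small-world network using the Watts-Strogatz model.

    Parameters:
    - n (int): Number of nodes (agents).
    - k (int): Each node is connected to its k nearest neighbors in the initial ring lattice.
    - p (float): Probability of rewiring each edge (introduces randomness).

    Returns:
    - networkx.Graph: Generated network.
    """
    return nx.watts_strogatz_graph(n, k, p)


def generate_small_world_of_agents(model, num_agents, num_init_connections, num_infected):
    """
    Create SIR agents on a small-world network topology.

    Parameters:
    - model (SIRModel): An instance of the SIR model.
    - num_agents (int): Total number of agents.
    - num_init_connections (int): Each agent is initially connected to this many neighbors.
    - num_infected (int): Number of initially infected agents.

    Returns:
    - SIRModel: Initialized model with agents and connections.
    """
    network = generate_small_world_network(num_agents, num_init_connections, 0.2)

    # One agent per node, all susceptible, each with its own random behavior vector
    node_to_agent = {}
    for n in network.nodes:
        preventative_measures = [random() for _ in range(100)]
        node_to_agent[n] = model.create_agent(SIRState.SUSCEPTIBLE.value, preventative_measures)

    # Infect a random subset of agents
    for n in sample(sorted(network.nodes), num_infected):
        model.set_agent_property_value(node_to_agent[n], "state", SIRState.INFECTED.value)

    # Mirror the network's edges as agent connections
    for u, v in network.edges:
        model.connect_agents(node_to_agent[u], node_to_agent[v])

    return model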


## 👥 Create and Initialize Agents
Agents are placed on the network; all start susceptible, except for a randomly chosen subset that starts infected.


In this simple SIR model, we define only **one breed**, representing the general population. 

## ⚙️ Define Simulation Parameters
You can modify these values to run different configurations.

In [14]:
num_agents = 1000
num_init_connections = 6
num_nodes = 1  # Logical nodes, can be used for partitioning

## 🛠️ Set up the SIR Model

In [None]:
model = SIRModel()
model.setup(use_gpu=True)  # Enables GPU acceleration if available


## 🧱 Build the Agent Network

In [None]:
model_creation_start = time()

model = generate_small_world_of_agents(
    model,
    num_agents,
    num_init_connections,
    int(0.1 * num_agents),  # 10% initially infected
)

model_creation_end = time()
model_creation_duration = model_creation_end - model_creation_start
print(f"Model creation took {model_creation_duration:.2f} seconds.")

## ▶️ Run the Simulation

In [None]:
simulate_start = time()

model.simulate(num_ticks=10, sync_workers_every_n_ticks=1)

simulate_end = time()
simulate_duration = simulate_end - simulate_start
print(f"Simulation took {simulate_duration:.2f} seconds.")
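
For small-scale debugging it can help to compare against a serial, pure-Python reference of the same transmission rule. The sketch below is not part of `sagesim`: `reference_run` is a hypothetical helper that applies the S → I rule from `step_func` (one shared random draw per susceptible agent, no recovery) on a `networkx` graph and records how many agents are in each state per tick. It assumes synchronous updates, where infections found during a tick take effect only at the next tick; the update order inside `sagesim` may differ, so counts are only comparable in distribution.

```python
import random
from collections import Counter
from statistics import mean

import networkx as nx

SUSCEPTIBLE, INFECTED, RECOVERED = 1, 2, 3  # same codes as SIRState


def reference_run(graph, pm, num_infected, num_ticks, p=1.0, alpha=1.0, beta=1.0, seed=0):
    """Serial reference of step_func's S -> I rule; returns per-tick state counts."""
    rng = random.Random(seed)
    state = {n: SUSCEPTIBLE for n in graph.nodes}
    for n in rng.sample(sorted(graph.nodes), num_infected):
        state[n] = INFECTED
    protection = {n: mean(pm[n]) for n in graph.nodes}  # s = mean of the behavior vector

    history = [Counter(state.values())]
    for _ in range(num_ticks):
        newly_infected = []
        for a in graph.nodes:
            if state[a] != SUSCEPTIBLE:
                continue
            r = rng.random()  # one draw shared across all neighbors, as in step_func
            for b in graph.neighbors(a):
                if state[b] == INFECTED:
                    p_eff = p * (1 - protection[a]) ** alpha * (1 - protection[b]) ** beta
                    if r < p_eff:
                        newly_infected.append(a)
                        break
        for a in newly_infected:  # synchronous update: apply after the sweep
            state[a] = INFECTED
        history.append(Counter(state.values()))
    return history


g = nx.watts_strogatz_graph(200, 6, 0.2, seed=1)
pm_rng = random.Random(1)
pm = {n: [pm_rng.random() for _ in range(100)] for n in g.nodes}
for tick, counts in enumerate(reference_run(g, pm, num_infected=20, num_ticks=10)):
    print(tick, counts[SUSCEPTIBLE], counts[INFECTED], counts[RECOVERED])
```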

## 📊 Collect Final States

In [None]:
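from collections import Counter

if worker == 0:
    # Query each agent's state once and skip agents with no stored state
    raw_states = [
        model.get_agent_property_value(agent_id, property_name="state")
        for agent_id in range(num_agents)
    ]
    result = [SIRState(int(s)) for s in raw_states if s is not None]
    # Summarize how many agents ended in each state, then show a small sample
    print(f"Final state distribution: {dict(Counter(state.name for state in result))}")
    print(f"First 10 agents: {[state.name for state in result[:10]]}")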

## 💾 Save Execution Metrics

In [None]:
import csv

if worker == 0:
    with open("execution_times.csv", "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            num_agents,
            num_init_connections,
            num_nodes,
            num_workers,
            model_creation_duration,
            simulate_duration
        ])
    print("Execution time written to 'execution_times.csv'.")
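
Because each run appends a header-less row, comparing configurations later requires knowing the column order. A minimal sketch for reading the file back, assuming the column order of the `writer.writerow` call above; `load_execution_times` and `COLUMNS` are illustrative names, not part of `sagesim`.

```python
import csv

# Column order follows the writer.writerow call above; the file has no header row.
COLUMNS = [
    "num_agents",
    "num_init_connections",
    "num_nodes",
    "num_workers",
    "model_creation_duration",
    "simulate_duration",
]


def load_execution_times(path="execution_times.csv"):
    """Parse the appended timing rows into dicts keyed by column name."""
    rows = []
    with open(path, newline="") as f:
        for raw in csv.reader(f):
            if not raw:
                continue  # skip blank lines
            record = dict(zip(COLUMNS, raw))
            for key in COLUMNS[:4]:  # integer configuration columns
                record[key] = int(record[key])
            for key in COLUMNS[4:]:  # durations in seconds
                record[key] = float(record[key])
            rows.append(record)
    return rows
```

For example, `[(r["num_workers"], r["simulate_duration"]) for r in load_execution_times()]` pairs each run's worker count with its simulation time, which is a starting point for a scaling plot.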