In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import starsim as ss # starsim is the Starsim framework
from zombie import * # zombie is a custom zombie library, see zombie.py

# Numerical libraries and utilities
import numpy as np
import pandas as pd
import sciris as sc

# Plotting libraries
import seaborn as sns
from matplotlib import pyplot
import matplotlib.ticker as mtick
pyplot.rcParams['figure.dpi'] = 240

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

Starsim 1.0.1 (2024-07-22) — © 2023-2024 by IDM


---
# Zombie Apocalypse: a case study in Starsim calibration, workflows, and saving the world
#### Adapted from Zombies: an introduction to Starsim by Dan Klein, 2024

A deep meta-analysis of recent and historical literature reveals that there are at least three zombiism transmission routes:
1. Be attacked by a zombie, and survive (to become another zombie)!
2. Acquire zombiism congenitally through vertical transmission during the prenatal period.
3. Die of natural causes and be transformed into a zombie upon death. 

Zombies have a few distinguishing features:
* There are two distinct types of zombies: slow and fast. The fast zombies are more aggressive and have more contacts than slow zombies.
* A zombie attack will always result in one of three outcomes: the victim becomes a zombie, the victim dies, or the victim survives.
* Zombies effectively don't die naturally, but they can slow down over time.
* There are no asymptomatic zombies. All zombies show symptoms immediately upon infection.

**Zombie** is an extension of the SIR disease that adds some important features:
  * **p_fast** is the probability of a zombie being fast.
  * **dur_fast** is the duration of a zombie being fast before reverting to slow, default 1000 years.
  * **dur_inf** is the duration of zombie infection, default is 1000 years. Once a zombie, always a zombie!
  * **p_symptomatic** is the probability of showing symptoms, default assumption is 100%.
  * **p_death_on_zombie_infection** is the probability of death when converting to a zombie.

**DeathZombies** is an extension of the base **Deaths** demographic class that captures people at natural death and potentially makes them zombies
  * **p_zombie_on_natural_death** is the probability of becoming a zombie on death due to natural causes

The **Pregnancy** and **MaternalNet** modules work together to simulate pregnancy, forming network connections between mothers and pre-birth children on which disease (zombie) transmission can occur.

We include an intervention, **KillZombies**, that kills only *symptomatic* zombies.

Configuring and running the simulation for one month takes less than one second.

---

# The scenario:

A zombie outbreak has occurred in a small town. The situation is grim: over the course of a single month the outbreak has spread to nearly every resident of the town. The epicenter of the outbreak is a local hospital, where the first cases were reported. There were 5 reported cases of the mysterious virus on the first day. 

Officials have quarantined the town and are asking for our help in considering various interventions to control the outbreak, but first we need to be sure we understand the disease dynamics. We already have a Zombie model, so let's try to calibrate it to see if we can accurately simulate the outbreak.

To model how interventions might control the outbreak, we must first understand the basic characteristics of the virus:
* beta: how likely contact with a zombie is to result in death and/or zombification
* p_fast: how many zombies are fast (and thus have more network contacts)
* p_death_on_zombie_infection: how likely a person is to die from a zombie attack instead of also becoming a zombie
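
To see why p_fast matters so much, here is a rough sketch of the average number of contacts per zombie as a function of p_fast. It assumes the contact rates used later in this notebook's `choose_degree` function (a Poisson mean of 4 for everyone except fast zombies, who get 50) and that fast zombies stay fast for the whole month. `mean_zombie_degree` is a hypothetical helper, not part of Starsim or zombie.py.

```python
# Values mirror this notebook's choose_degree function (assumed here, not read from the model)
SLOW_DEGREE = 4   # Poisson mean contacts for humans and slow zombies
FAST_DEGREE = 50  # Poisson mean contacts for fast zombies

def mean_zombie_degree(p_fast, slow=SLOW_DEGREE, fast=FAST_DEGREE):
    """Expected mean number of contacts per zombie for a given p_fast."""
    if not 0.0 <= p_fast <= 1.0:
        raise ValueError('p_fast must be a probability between 0 and 1')
    return p_fast * fast + (1.0 - p_fast) * slow

for p in (0.0, 0.1, 0.5, 1.0):
    print(f'p_fast={p:.1f} -> mean contacts per zombie = {mean_zombie_degree(p):.1f}')
```

Even a small share of fast zombies roughly doubles the average contact rate in this sketch, which is why p_fast and beta can be hard to tell apart from case counts alone.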

Health officials have provided us with a dataset of the number of zombies present in the town each day. They also estimate that 20% of people who die naturally have become zombies since the onset of the outbreak.
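
As a back-of-the-envelope check on how much the natural-death route contributes, we can combine that 20% estimate with the population size (5,000) and crude death rate (15 per 1,000, assumed here to be annual) configured later in this notebook. This sketch ignores changes in population size during the month.

```python
n_agents = 5_000                  # town population used in this notebook
death_rate_per_1000 = 15          # crude death rate, assumed to be per year
p_zombie_on_natural_death = 0.2   # officials' estimate
months = 1                        # length of the simulated outbreak

# Expected natural deaths over the period, then the share that become zombies
expected_deaths = n_agents * death_rate_per_1000 / 1000 * months / 12
expected_zombies = expected_deaths * p_zombie_on_natural_death

print(f'Expected natural deaths in {months} month(s): {expected_deaths:.2f}')
print(f'Expected zombies from natural death: {expected_zombies:.2f}')
```

Under these assumptions natural deaths add only about one zombie over the month, which suggests the outbreak curve is shaped almost entirely by the transmission parameters we are about to calibrate.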

So far every intervention attempted has had no impact, and no zombies have been killed. 

In [7]:
# Implement custom calibration class
class DailyCalibration(ss.Calibration):
    """
    A class to handle calibration of Starsim simulations against daily data. Uses the
    Optuna hyperparameter optimization library (optuna.org), which must be installed
    separately (via pip install optuna).
    Args:
        sim          (Sim)  : the simulation to calibrate
        data         (df)   : pandas dataframe with an index column (see index_name) and one column per sim result to fit
        calib_pars   (dict) : a dictionary of the parameters to calibrate of the format dict(key1=[best, low, high])
        weights      (dict) : optional weights per result (stored; not applied by the compute_fit() defined below)
        index_name   (str)  : the name of the data column used to align data and sim results (default: 'day')
        fit_args     (dict) : a dictionary of goodness-of-fit options (stored; not used by the compute_fit() defined below)
        par_samplers (dict) : an optional mapping from parameters to the Optuna sampler to use for choosing new points for each; by default, suggest_float
        n_trials     (int)  : the number of trials per worker (default: 20)
        n_workers    (int)  : the number of parallel workers (default: the number of CPUs)
        total_trials (int)  : if supplied, n_trials is set to ceil(total_trials / n_workers)
        name         (str)  : the name of the Optuna study (default: 'starsim_calibration')
        db_name      (str)  : the name of the database file (default: '<name>.db')
        estimator    (any)  : accepted for compatibility; not used by this class
        keep_db      (bool) : whether to keep the database after calibration (default: false)
        storage      (str)  : the location of the database (default: sqlite)
        rand_seed    (int)  : if provided, use this random seed to initialize Optuna runs (for reproducibility)
        sampler      (any)  : an optional Optuna sampler, stored in run_args
        label        (str)  : a label for this calibration object
        die          (bool) : whether to stop if an exception is encountered (default: false)
        verbose      (bool) : whether to print details of the calibration

    Returns:
        A Calibration object
    """
    def __init__(self, sim, data, calib_pars=None, weights=None, index_name=None, fit_args=None, par_samplers=None,
                 n_trials=None, n_workers=None, total_trials=None, name=None, db_name=None, estimator=None,
                 keep_db=None, storage=None, rand_seed=None, sampler=None, label=None, die=False, verbose=True):

        import multiprocessing as mp

        # Handle run arguments
        if n_trials  is None: n_trials  = 20
        if n_workers is None: n_workers = mp.cpu_count()
        if name      is None: name      = 'starsim_calibration'
        if db_name   is None: db_name   = f'{name}.db'
        if keep_db   is None: keep_db   = False
        if storage   is None: storage   = f'sqlite:///{db_name}'
        if index_name is None: index_name = 'day'
        if total_trials is not None: n_trials = int(np.ceil(total_trials/n_workers))
        self.run_args   = sc.objdict(n_trials=int(n_trials), n_workers=int(n_workers), name=name, db_name=db_name,
                                     keep_db=keep_db, storage=storage, rand_seed=rand_seed, sampler=sampler)

        # Handle other inputs
        self.label          = label
        self.sim            = sim
        self.calib_pars     = calib_pars
        self.weights        = weights
        self.fit_args       = sc.mergedicts(fit_args)
        self.index_name = index_name
        self.par_samplers   = sc.mergedicts(par_samplers)
        self.die            = die
        self.verbose        = verbose
        self.calibrated     = False

        # Load data -- this expects a dataframe with an index column (index_name, default 'day')
        # and one column for each sim result to fit
        if not isinstance(data, pd.DataFrame):
            errormsg = 'Please pass data as a pandas dataframe'
            raise ValueError(errormsg)
        # set_index returns a new dataframe, so the caller's data is left unmodified
        self.target_data = data.set_index(self.index_name)

        # Temporarily store a filename
        self.tmp_filename = 'tmp_calibration_%05i.obj'

        # Initialize sim
        if not self.sim.initialized:
            self.sim.initialize()

        # Figure out which sim results to get
        self.sim_result_list = self.target_data.columns.values.tolist()

        return
    
    
    def run_trial(self, trial, save=True):
        """ Define the objective for Optuna """
        if self.calib_pars is not None:
            calib_pars = self.trial_to_sim_pars(self.calib_pars, trial)
        else:
            calib_pars = None
        sim = self.run_sim(calib_pars)

        # Export results
        df_res = sim.export_df()
        df_res[self.index_name] = np.floor(np.round(df_res.index, 1)).astype(int)
        sim_results = sc.objdict()

        for skey in self.sim_result_list:
            if 'prevalence' in skey:
                model_output = df_res.groupby(by=self.index_name)[skey].mean()
            else:
                model_output = df_res.groupby(by=self.index_name)[skey].sum()
            sim_results[skey] = model_output.values

        sim_results[self.index_name] = model_output.index.values
        # Store results in temporary files
        if save:
            filename = self.tmp_filename % trial.number
            sc.save(filename, sim_results)

        # Compute fit
        fit = self.compute_fit(sim)
        return fit
   
    def compute_fit(self, sim):
        """ Compute goodness-of-fit """
        fit = 0
        df_res = sim.export_df()
        df_res[self.index_name] = np.floor(np.round(df_res.index, 1)).astype(int)
        for skey in self.sim_result_list:
            if 'prevalence' in skey:
                model_output = df_res.groupby(by=self.index_name)[skey].mean()
            else:
                model_output = df_res.groupby(by=self.index_name)[skey].sum()

            # Align data and model output on the index; drop days missing from either
            data = self.target_data[skey]
            combined = pd.merge(data, model_output, how='left', on=self.index_name).dropna()
            gofs = ss.calibration.compute_gof(combined[skey+'_x'], combined[skey+'_y'])

            losses = gofs  #* self.weights[skey]
            mismatch = losses.sum()
            fit += mismatch

        return fit
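
The core of `compute_fit` is aligning model output with the observed data by day before scoring it. The sketch below reproduces that alignment with plain pandas on made-up numbers. The column names are hypothetical, and the summed absolute difference is a simple stand-in for `ss.calibration.compute_gof`, not a reimplementation of it.

```python
import pandas as pd

# Made-up model output with two rows per day (column names are hypothetical)
model = pd.DataFrame({
    'day': [0, 0, 1, 1, 2, 2],
    'zombie.new_infections': [1, 2, 4, 5, 7, 9],
    'zombie.prevalence': [0.001, 0.003, 0.004, 0.006, 0.010, 0.014],
})

# Made-up observations; day 3 has no matching model output
observed = pd.DataFrame({
    'day': [0, 1, 2, 3],
    'zombie.new_infections': [3, 8, 15, 30],
}).set_index('day')

def aggregate_by_day(df, key, index_name='day'):
    """Average prevalence-like results within a day; sum everything else."""
    grouped = df.groupby(index_name)[key]
    return grouped.mean() if 'prevalence' in key else grouped.sum()

def align(observed, model_daily, key, index_name='day'):
    """Left-join model output onto the data and drop days without a match."""
    left = observed[key].rename('data').reset_index()
    right = model_daily.rename('model').reset_index()
    return pd.merge(left, right, how='left', on=index_name).dropna()

key = 'zombie.new_infections'
model_daily = aggregate_by_day(model, key)
aligned = align(observed, model_daily, key)

# Stand-in mismatch: summed absolute difference between data and model
mismatch = (aligned['data'] - aligned['model']).abs().sum()
print(aligned)
print('mismatch:', mismatch)
```

Because the join is a left join followed by `dropna`, any observed day the model does not cover is silently excluded from the fit, which is worth keeping in mind if the data and the sim span different periods.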

In [8]:
# Define the calibration parameters
calib_pars = dict(
    diseases = dict(
        zombie = dict(
            beta = [5, 1, 100],
            p_fast = [0.1, 0.0, 1.0],
            p_death_on_zombie_infection = [0.25, 0.0, 1.0],
        ),
    ),
)

# Load the calibration data
data = pd.read_csv('./zombie_outbreak_data.csv')



# Create the Sim


people = ss.People(n_agents=5_000) # 5000 people live in this town

# Configure and create an instance of the Zombie class
zombie_pars = dict(
    init_prev = 0.001, # 5 people were initially infected
    beta = {'random': 1, 'maternal': 1}, # Guesses. To be calibrated
    p_fast = ss.bernoulli(p=0.1), # Guesses. To be calibrated
    p_death_on_zombie_infection = ss.bernoulli(p=0.25), # Guesses. To be calibrated
    p_symptomatic = ss.bernoulli(p=1.0),
)
zombie = Zombie(zombie_pars)

# This function allows the lambda parameter of the poisson distribution used to determine
# n_contacts to vary based on zombie type
def choose_degree(self, sim, uids):
    mean_degree = np.full(fill_value=4, shape=len(uids)) # Default value is 4
    zombie = sim.diseases['zombie'] 
    is_fast = zombie.infected[uids] & zombie.fast[uids]
    mean_degree[is_fast] = 50 # Fast zombies get 50
    return mean_degree

# We create two network layers, random and maternal
networks = [
    ss.RandomNet(n_contacts=ss.poisson(lam=choose_degree)),
    ss.MaternalNet()
]

# Configure and create demographic modules
death_pars = dict(
    death_rate = 15, # per 1,000 (during normal times!)
    p_zombie_on_natural_death = ss.bernoulli(p=0.2), # Estimate based on observed data
)
deaths = DeathZombies(**death_pars)
births = ss.Pregnancy(fertility_rate=175) # per 1,000 women 15-49 annually (during normal times!)
demog = [births, deaths]

# And finally bring everything together in a sim
sim_pars = dict(start=2024, end=2024+1/12, dt=1/365, verbose=0)
sim = ss.Sim(sim_pars, people=people, diseases=zombie, networks=networks, demographics=demog)


# Create the calibration object
calib = DailyCalibration(calib_pars=calib_pars,
                       sim=sim,
                       data=data,
                       total_trials=10,
                       n_workers=1,
                       name="test",
                       keep_db=False,
                       die=True)

calib.calibrate(confirm_fit=True)
# Run the sim and plot results
#sim.run()


# Package results (the sim must have been run for sim.results to be populated)
df = pd.DataFrame( {
    'Year': sim.yearvec,
    'Population': sim.results.n_alive,
    'Humans': sim.results.n_alive - sim.results.zombie.n_infected,
    'Zombies': sim.results.zombie.n_infected,
    'Zombie Prevalence': sim.results.zombie.prevalence,
    'Congenital Zombies (cum)': sim.results.zombie.cum_congenital,
    'Zombie-Cause Mortality': sim.results.zombie.cum_deaths,
})

Error displaying custom sim repr, falling back to default: 'list' object has no attribute 'keys'
  ss.warn(f'Error displaying custom sim repr, falling back to default: {E}')
Trying to access non-initialized Arr object; in most cases, Arr objects need to be initialized with a Sim object, but set skip_init=True if this is intentional.
  ss.warn('Trying to access non-initialized Arr object; in most cases, Arr objects need to be initialized with a Sim object, but set skip_init=

Could not delete study, skipping...
[WinError 32] The process cannot access the file because it is being used by another process: 'test.db'
sqlite:///test.db


DuplicatedStudyError: Another study with name 'test' already exists. Please specify a different name, or reuse the existing one by setting `load_if_exists` (for Python API) or `--skip-if-exists` flag (for CLI).
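
The error above appears to follow from the lines just before it: the old `test.db` could not be deleted, so the study named 'test' was still there when a new one was created. One possible workaround, sketched here, is to give each calibration run a fresh name. Because `DailyCalibration` derives `db_name` from `name` by default, a new name also means a new database file. `unique_study_name` is a hypothetical helper, not part of Starsim or Optuna.

```python
import uuid
from datetime import datetime

def unique_study_name(prefix='test'):
    """Return a study name unlikely to collide with one left in an old database file."""
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')  # human-readable run time
    suffix = uuid.uuid4().hex[:6]                     # random part guards against same-second reruns
    return f'{prefix}_{stamp}_{suffix}'

name = unique_study_name('test')
print(name)  # pass this as DailyCalibration(..., name=name)
```

The trade-off is that stale `.db` files can accumulate if they are never cleaned up, so this suits quick interactive reruns more than long-lived workflows.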

In [17]:
# How does this compare with the observed data?