In [1]:
from scipy.stats import truncnorm
import pandas as pd
import numpy as onp
import itertools
import datetime
import tqdm
import sys
import os

import jax.numpy as np

def flatten_list(list_array):
    return list(itertools.chain(*list_array))

sys.path.insert(0,"../")
from global_config import config

results_dir           = config.get_property('results_dir')
data_dir              = config.get_property('data_dir')
paper_dir             = config.get_property('paper_dir')
data_db_dir           = config.get_property('data_db_dir')
feb_hosp_records_path = os.path.join(data_db_dir, 'long_files_8_25_2021')
path_to_save          = os.path.join(results_dir, "real_testing", "community")


In [2]:
class Patient:
    susceptible = 0
    colonized   = 1


In [3]:
dates_simulation = pd.date_range(start="2020-02-01", end="2021-02-28", freq="D")

movement_df               = pd.read_csv(os.path.join(data_db_dir, "long_files_8_25_2021", 'patient_movement_2022-Nov.csv'), parse_dates=['date']).drop_duplicates(subset=["date", "mrn"], keep="first")
movement_df["ward_total"] = movement_df.apply(lambda x: x["ward"]+"-"+x["building"]+"-"+x["place"], axis=1)
movement_df               = movement_df[movement_df["date"].isin(dates_simulation)]

mrd2id  = {mrn: id for id, mrn in enumerate(movement_df.mrn.unique())}
ward2id = {ward_name: id for id, ward_name in enumerate(movement_df.ward_total.unique())}

movement_df["mrn_id"]  = movement_df.mrn.map(mrd2id)
movement_df["ward_id"] = movement_df.ward_total.map(ward2id)

ward_size_df                 = movement_df.reset_index()
ward_size_df["ward_id"]      = ward_size_df["ward_total"].apply(lambda x: ward2id[x])
ward_size_df["num_patients"] = 1
ward_size_df                 = ward_size_df.groupby(["date", "ward", "ward_id"]).sum()[["num_patients"]].reset_index().drop(columns=["date"])
ward_size_df                 = ward_size_df.groupby(["ward", "ward_id"]).mean().reset_index().sort_values(by="num_patients")
ward2size                    = {r.ward_id: r.num_patients for idx_r, r in ward_size_df.iterrows()}


  ward_size_df                 = ward_size_df.groupby(["date", "ward", "ward_id"]).sum()[["num_patients"]].reset_index().drop(columns=["date"])


In [4]:
from jax import random

Np  = 10000
key = random.PRNGKey(0)



In [47]:
from jax import random

@jit
def decolonize(agents_state, alpha, Np):
    return agents_state - agents_state * (random.uniform(key=key, shape=(Np, )) < alpha)

def amr_abm(agents_state, movement, ward2size, parameters):
    """ Agent based model tracking colonized and susceptible patients with pre-defined movement patterns.

    Args:
        agents_state : agent state. {0: Patient.susceptible, 1: Patient.colonized}  Size: (n_patients)
        movement     : pd.Dataframe with patient locations and culture information.
        parameters   : dictionary of parameters, contains importation rate (gamma), nosocomial transmission rate (beta),
                        effective sensitivity (ro), and decolonization rate (alpha)
    """

    γ   = parameters["gamma"] # importation rate.
    β   = parameters["beta"]  # nosocomial transmission rate.
    ρ   = parameters["ro"]    # effective sensitivity.
    α   = parameters["alpha"] # decolonization rate
    Nw  = parameters["Nw"]    # number of wards
    Np  = parameters["Np"]    # number of patients

    # Decolonize patients according to. P(C2S) = α
    #agents_state.at[:].set(np.maximum(agents_state - agents_state * (random.uniform(key=key, shape=(Np, )) < α), 0))
    #agents_state = agents_state - agents_state * (onp.random.uniform(size=(Np, )) < α)
    agents_state = np.add(agents_state, - agents_state * (onp.random.uniform(size=(Np, )) < α))
    agents_state = np.clip(agents_state, 0, 1)

    # import patients.
    new_patients = movement[movement["first_day"]==1]["mrn_id"].values
    if new_patients.shape[0] > 0:
        # P(S2C) = \gamma - Probability of colonization given importation.
        #agents_state[new_patients] = Patient.colonized * (random.uniform(key=key, shape=(new_patients.shape[0], )) < γ)
        agents_state[new_patients] = Patient.colonized * (onp.random.uniform(size=(new_patients.shape[0], )) < γ)

    # compute force of infection for each ward.
    for i, ward_id in enumerate(movement["ward_id"].unique()):

        patients_ward = movement[movement["ward_id"]==ward_id]["mrn_id"].values

        # λ_i = β  * C / N  - Force of infection for ward i.
        λ_i = β * np.sum(agents_state[patients_ward]==Patient.colonized) / ward2size[ward_id]

        # P(C2S)_i = λ_i, we add the state but if already colonized the state would be 2 so we clip it to 1.
        #agents_state[patients_ward] = Patient.colonized * np.minimum(agents_state[patients_ward] + random.uniform(key=key, shape=(patients_ward.shape[0], )) < λ_i, 1)
        agents_state.at[patients_ward].set(np.add(agents_state[patients_ward],  Patient.colonized * (onp.random.uniform(size=(patients_ward.shape[0], )) < λ_i)))
        agents_state.at[patients_ward].set(np.clip(agents_state[patients_ward], 0, 1))

    return agents_state


In [46]:
from jax import jit, vmap

dates_simulation = pd.date_range(start="2020-02-01", end="2021-02-28", freq="D")

parameters          = {}
parameters["gamma"] = 15/100
parameters["beta"]  = 0.02
parameters["ro"]    = 6/100
parameters["alpha"] = 1/120
parameters["Nw"]    = len(ward2id)
parameters["Np"]    = movement_df.mrn_id.unique().shape[0]

colonized    = []

key   = random.PRNGKey(0)
model = lambda pstatus, movement: amr_abm(pstatus, movement, ward2size, parameters)

from tqdm import tqdm
import numpy as onp

agents_state = np.zeros(parameters["Np"])

for date in tqdm(dates_simulation):
    movement     = movement_df[movement_df["date"]==date]
    agents_state = model(agents_state, movement)
    colonized.append(np.sum(agents_state==Patient.colonized))

from utils.plot_utils import *

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(dates_simulation, colonized)
format_axis(ax)


  0%|          | 0/394 [00:00<?, ?it/s]


AttributeError: 'numpy.ndarray' object has no attribute 'at'

In [44]:
model_ensemble = vmap(amr_abm, in_axes=(0, None, None, None))
agents_state   = np.zeros((2, movement_df.mrn_id.unique().shape[0]))

for i, date in enumerate(dates_simulation):
    movement     = movement_df[movement_df["date"]==date]
    print(i)
    agents_state = model_ensemble(agents_state, movement, ward2size, parameters)


0


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[103510])>with<BatchTrace(level=1/0)> with
  val = array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])
  batch_dim = 0
This BatchTracer with object id 140283539435328 was created on line:
  /var/folders/7p/jp6xqkvn5wb6ddl1fn0bhs980000gn/T/ipykernel_90878/2888077412.py:7 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError