In [1]:
import numpy as np
import scipy as sp
from epymorph.kit import *

In [2]:
from epymorph.params import ParamFunctionNumpy


class UniformPrior:
    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def sample(self, size, rng):
        return sp.stats.uniform.rvs(
            loc=self.lower, scale=(self.upper - self.lower), size=size, random_state=rng
        )


class GBM(ParamFunctionNumpy):
    def __init__(self, initial=None, voliatility=0.2):
        self.initial = initial
        self.voliatility = voliatility
        super().__init__()

    def evaluate(self):
        result = np.zeros(shape=(self.time_frame.days, self.scope.nodes))
        if self.initial is None:
            # Default behavior if initial is None (e.g., initialize with 1.0)
            result[0, :] = np.ones(self.scope.nodes)
        elif isinstance(self.initial, UniformPrior):
            # Sample from the UniformPrior if initial is a UniformPrior object
            result[0, :] = self.initial.sample(self.scope.nodes, self.rng)
        else:
            # Otherwise, use the initial value directly (e.g., fixed number like 100.0)
            result[0, :] = np.full(self.scope.nodes, self.initial)
        for i in range(self.time_frame.days - 1):
            result[i + 1, :] = np.exp(
                self.rng.normal(np.log(result[i, :]), self.voliatility)
            )
        return result

In [3]:
from epymorph.adrio import acs5
from epymorph import *  # noqa: F403
from epymorph.geography.us_census import StateScope
from epymorph.data.ipm.sirh import SIRH
from epymorph.data.mm.no import No
from epymorph.initializer import Proportional
from epymorph.rume import SingleStrataRUME
from epymorph.time import TimeFrame


In [146]:
from types import SimpleNamespace
from typing import Callable
from epymorph import initializer
from epymorph.attribute import NamePattern
from epymorph.simulation import ParamValue
from epymorph.simulator.basic.basic_simulator import RUMEType
from epymorph.util import CovariantMapping
import dataclasses


class ForecastSimulator:
    """The RUME we will use for the simulation."""

    @staticmethod
    def run(
        rume: RUMEType,
        num_realizations: int,
        params: CovariantMapping[str | NamePattern, ParamValue] | None = None,
        rng_factory: Callable[[], np.random.Generator] | None = None,
    ):
        days = rume.time_frame.days
        taus = rume.num_tau_steps
        R = num_realizations
        S = days * taus
        N = rume.scope.nodes
        C = rume.ipm.num_compartments
        E = rume.ipm.num_events

        initial = np.empty(shape=(R, N, C), dtype=np.int64)
        compartments = np.empty(shape=(R, S, N, C), dtype=np.int64)
        events = np.empty(shape=(R, S, N, E), dtype=np.int64)

        rng = (rng_factory or np.random.default_rng)()

        dummy_data = rume.evaluate_params(override_params=params, rng=rng)

        return_params = {}
        for key in dummy_data.to_dict().keys():
            # Store the results with the shape of (number of realizations, shape of the raw data)
            return_params[str(key)] = np.empty(
                shape=(R,) + dummy_data.get_raw(key).shape,
                dtype=dummy_data.get_raw(key).dtype,
            )

        for i in range(R):
            data = rume.evaluate_params(override_params=params, rng=rng)
            evaluated_params = dict(
                zip(
                    [str(key) for key in data.to_dict().keys()],
                    [data.get_raw(key) for key in data.to_dict().keys()],
                )
            )

            sim = BasicSimulator(rume)

            out = sim.run(params=evaluated_params, rng_factory=(lambda: rng))

            initial[i, ...] = out.initial
            compartments[i, ...] = out.compartments
            events[i, ...] = out.events

            for key in evaluated_params.keys():
                return_params[key][i, ...] = evaluated_params[key]

        return SimpleNamespace(
            rume=rume,
            initial=initial,
            compartments=compartments,
            events=events,
            params=return_params,
        )

    @staticmethod
    def extend(
        output,
        duration: int,
        params: dict,  # CovariantMapping[str | NamePattern, ParamValue],
        rng_factory: Callable[[], np.random.Generator] | None = None,
    ):
        rume = output.rume
        days = duration
        taus = rume.num_tau_steps
        R = output.compartments.shape[0]
        print("num realisations = ", R)
        S = days * taus
        N = rume.scope.nodes
        C = rume.ipm.num_compartments
        E = rume.ipm.num_events
        forecast_start_date = rume.time_frame.end_date.strftime("%Y-%m-%d")

        initial = np.empty(shape=(R, N, C), dtype=np.int64)
        compartments = np.empty(shape=(R, S, N, C), dtype=np.int64)
        events = np.empty(shape=(R, S, N, E), dtype=np.int64)

        rng = (rng_factory or np.random.default_rng)()

        param_names = list(params.keys())
        parameters = {}
        for param in param_names:
            parameters[param] = np.array(
                [0.2]
            )  # using a default value to create dummy data

        dummy_data = rume.evaluate_params(override_params=parameters, rng=rng)

        return_params = {}
        for key in dummy_data.to_dict().keys():
            # Store the results with the shape of (number of realizations, shape of the raw data)
            return_params[str(key)] = np.empty(
                shape=(R,) + dummy_data.get_raw(key).shape,
                dtype=dummy_data.get_raw(key).dtype,
            )

        print("return_params = ", return_params["gpm:all::ipm::beta"].shape)

        for i in range(R):
            parameters = {}
            for param in param_names:
                param_array = output.params["gpm:all::ipm::" + param]
                indexing_tuple = (
                    (i, -1, slice(None)) if param_array.ndim == 3 else (i, slice(None))
                )
                if not isinstance(params[param], (int, float, np.ndarray)):
                    params[param].initial = param_array[indexing_tuple]
                    parameters[param] = params[param]
                else:
                    parameters[param] = params[param]

            rume_propagate = dataclasses.replace(
                rume,
                time_frame=TimeFrame.of(forecast_start_date, duration),
                strata=[
                    dataclasses.replace(
                        g,
                        init=initializer.Explicit(initials=output.compartments[i][-1]),
                    )  # Initialize with state values
                    for g in output.rume.strata  # For each stratum, set the initial state
                ],
            )

            data = rume_propagate.evaluate_params(
                override_params=parameters, rng=np.random.default_rng(seed=1)
            )
            evaluated_params = dict(
                zip(
                    [str(key) for key in data.to_dict().keys()],
                    [data.get_raw(key) for key in data.to_dict().keys()],
                )
            )

            print("evaluated_params = ", evaluated_params["gpm:all::ipm::beta"].shape)

            sim = BasicSimulator(rume_propagate)

            # Run the simulation and collect the output based on observations
            # (dynamic params)
            out = sim.run(evaluated_params, rng_factory=(lambda: rng))

            initial[i, ...] = out.initial
            compartments[i, ...] = out.compartments
            events[i, ...] = out.events

            for key in evaluated_params.keys():
                print("key = ", key)
                print(
                    "return params[key] ", return_params[key].shape, return_params[key]
                )
                print(
                    "evaluated_params[key] ",
                    evaluated_params[key].shape,
                    evaluated_params[key],
                )
                return_params[key][i, ...] = evaluated_params[key]

        return SimpleNamespace(
            rume=rume,
            initial=initial,
            compartments=compartments,
            events=events,
            params=return_params,
        )


In [148]:
rume = SingleStrataRUME.build(
    ipm=SIRH(),
    mm=No(),
    scope=StateScope.in_states(["AZ"], year=2015),
    init=Proportional(ratios=np.array([9999, 1, 0, 0], dtype=np.int64)),
    time_frame=TimeFrame.of("2022-10-01", 2),
    params={
        "beta": GBM(initial=UniformPrior(lower=0.4, upper=0.41), voliatility=0.01),
        "gamma": 0.25,
        "xi": 1 / 365,  # 0.0111,
        "hospitalization_prob": 0.01,
        "hospitalization_duration": 5.0,
        "population": acs5.Population(),
    },
)

In [149]:
forecast_output = ForecastSimulator.run(rume=rume, num_realizations=10)

In [150]:
forecast_output.compartments.shape

(10, 2, 1, 4)

In [None]:
forecast_output.params

In [152]:
forecast_extend = ForecastSimulator.extend(
    forecast_output,
    5,
    params={"beta": GBM(initial=None, voliatility=0.01)},
)

num realisations =  10
return_params =  (10, 1)
evaluated_params =  (5, 1)
key =  gpm:all::ipm::beta
return params[key]  (10, 1) [[5.]
 [5.]
 [5.]
 [5.]
 [5.]
 [5.]
 [5.]
 [5.]
 [5.]
 [5.]]
evaluated_params[key]  (5, 1) [[0.40902122]
 [0.41043718]
 [0.4138233 ]
 [0.41519298]
 [0.40981747]]


ValueError: could not broadcast input array from shape (5,1) into shape (1,)

In [74]:
param_names = [key.split("::")[-1] for key in forecast_output.params.keys()]
param_names

['beta',
 'gamma',
 'xi',
 'hospitalization_prob',
 'hospitalization_duration',
 'population',
 'label']

In [101]:
forecast_output.params["gpm:all::ipm::beta"][0]

array([[0.40445935],
       [0.40286018],
       [0.39853281],
       [0.39466837],
       [0.39114458],
       [0.38441981],
       [0.38343928],
       [0.38083885],
       [0.38117172],
       [0.37722004]])

In [None]:
len(forecast_output.compartments)  # .shape[0]

100

In [None]:
rume = SingleStrataRUME.build(
    ipm=SIRH(),
    mm=No(),
    scope=StateScope.in_states(["AZ"], year=2015),
    init=Proportional(ratios=np.array([9999, 1, 0, 0], dtype=np.int64)),
    time_frame=TimeFrame.of("2022-10-01", 10),
    params={
        "beta": GBM(initial=None, voliatility=0.01),
        "gamma": 0.25,
        "xi": 1 / 365,  # 0.0111,
        "hospitalization_prob": 0.01,
        "hospitalization_duration": 5.0,
        "population": acs5.Population(),
    },
)

In [None]:
extracted_params = {}

# Loop over rume.params and extract key-value pairs
for key, value in rume.params.items():
    # You can use the 'id' as the key for simplicity
    extracted_params[key.id] = value

extracted_params

{'beta': <__main__.GBM at 0x20ab8c17710>,
 'gamma': 0.25,
 'xi': 0.0027397260273972603,
 'hospitalization_prob': 0.01,
 'hospitalization_duration': 5.0,
 'population': <epymorph.adrio.acs5.Population at 0x20ab8f70c90>}

In [109]:
rume.params
data = rume.evaluate_params(rng=np.random.default_rng(seed=1))
[data.get_raw(key) for key in data.to_dict().keys()]

[array([[0.40511822],
        [0.40846045],
        [0.40981239],
        [0.40450654],
        [0.40818539],
        [0.4100115 ],
        [0.40781583],
        [0.41019262],
        [0.4116908 ],
        [0.4129035 ]]),
 array(0.25),
 array(0.00273973),
 array(0.01),
 array(5.),
 array([6641928], dtype=int64),
 array(['AZ'], dtype='<U2')]

In [None]:
params = {"beta": GBM(initial=None, voliatility=0.01)}


def run(output, params):
    param_names = list(params.keys())
    for i in range(1):
        parameters = {}
        for param in param_names:
            param_array = output.params["gpm:all::ipm::" + param]
            indexing_tuple = (
                (i, -1, slice(None)) if param_array.ndim == 3 else (i, slice(None))
            )
            if not isinstance(params[param], (int, float, np.ndarray)):
                params[param].initial = param_array[indexing_tuple]
                parameters[param] = params[param]
            else:
                parameters[param] = params[param]

            data = rume.evaluate_params(
                override_params=parameters, rng=np.random.default_rng(seed=1)
            )
            evaluated_params = dict(
                zip(
                    [str(key) for key in data.to_dict().keys()],
                    [data.get_raw(key) for key in data.to_dict().keys()],
                )
            )

            print(evaluated_params)

In [None]:
run(forecast_output, params=None, rume=rume)

{'gpm:all::ipm::beta': array([[0.37722004],
       [0.37852591],
       [0.38164876],
       [0.38291195],
       [0.37795438],
       [0.38139175],
       [0.38309799],
       [0.38104644],
       [0.38326722],
       [0.38466706]]), 'gpm:all::ipm::gamma': array(0.25), 'gpm:all::ipm::xi': array(0.00273973), 'gpm:all::ipm::hospitalization_prob': array(0.01), 'gpm:all::ipm::hospitalization_duration': array(5.), 'gpm:all::init::population': array([6641928], dtype=int64), 'meta::geo::label': array(['AZ'], dtype='<U2')}


In [None]:
def run(output, params):
    param_names = list(params.keys())
    rume = output.rume
    evaluated_params = {}
    out = None
    forecast_start_date = rume.time_frame.end_date.strftime("%Y-%m-%d")
    for i in range(1):
        parameters = {}
        for param in param_names:
            param_array = output.params["gpm:all::ipm::" + param]
            indexing_tuple = (
                (i, -1, slice(None)) if param_array.ndim == 3 else (i, slice(None))
            )
            if not isinstance(params[param], (int, float, np.ndarray)):
                params[param].initial = param_array[indexing_tuple]
                parameters[param] = params[param]
            else:
                parameters[param] = params[param]

        rume_propagate = dataclasses.replace(
            rume,
            time_frame=TimeFrame.of(forecast_start_date, 10),
            strata=[
                dataclasses.replace(
                    g,
                    init=initializer.Explicit(initials=output.compartments[i][-1]),
                )  # Initialize with state values
                for g in output.rume.strata  # For each stratum, set the initial state
            ],
        )

        data = rume_propagate.evaluate_params(
            override_params=parameters, rng=np.random.default_rng(seed=1)
        )
        evaluated_params = dict(
            zip(
                [str(key) for key in data.to_dict().keys()],
                [data.get_raw(key) for key in data.to_dict().keys()],
            )
        )

        evaluated_params = evaluated_params
        print(evaluated_params)

        sim = BasicSimulator(rume_propagate)

        out = sim.run(evaluated_params)

    return evaluated_params, out


In [70]:
forecast_output.compartments[0][-1]

array([[6639471,    1358,    1094,       5]], dtype=int64)

In [71]:
evaluated_params, out = run(
    forecast_output, params={"beta": GBM(initial=None, voliatility=0.01)}
)

{'gpm:all::ipm::beta': array([[0.4105499 ],
       [0.41197115],
       [0.41536992],
       [0.41674473],
       [0.41134912],
       [0.4150902 ],
       [0.4169472 ],
       [0.41471439],
       [0.41713139],
       [0.41865491]]), 'gpm:all::ipm::gamma': array(0.25), 'gpm:all::ipm::xi': array(0.00273973), 'gpm:all::ipm::hospitalization_prob': array(0.01), 'gpm:all::ipm::hospitalization_duration': array(5.), 'meta::geo::label': array(['AZ'], dtype='<U2')}


In [None]:
sim = BasicSimulator(rume)

# Run the simulation and collect the output based on observations
# (dynamic params)
out = sim.run(evaluated_params)

In [87]:
not isinstance(params["beta"], (int, float, str, bool))

True

In [None]:
param_types = {}

for key, value in params.items():
    if callable(value):
        param_types[key] = "function"
    elif isinstance(value, object) and not isinstance(
        value, type
    ):  # Check if it's an instance of any class
        param_types[key] = "class instance"
    elif isinstance(
        value, type
    ):  # This means it's a class type itself, not an instance
        param_types[key] = "class"
    elif isinstance(
        value, (int, float, str, bool)
    ):  # Check if it's a static value (number, string, etc.)
        param_types[key] = "static value"
    else:
        param_types[key] = "unknown"

In [71]:
param_types

{'beta': 'class instance', 'gama': 'class instance'}