In [165]:
import numpy as np
import scipy as sp
from epymorph.kit import *

In [166]:
from epymorph.params import ParamFunctionNumpy


class UniformPrior:
    def __init__(self, lower, upper):
        self.lower = lower
        self.upper = upper

    def sample(self, size, rng):
        return sp.stats.uniform.rvs(
            loc=self.lower, scale=(self.upper - self.lower), size=size, random_state=rng
        )


class GBM(ParamFunctionNumpy):
    def __init__(self, initial=None, voliatility=0.2):
        self.initial = initial
        self.voliatility = voliatility
        super().__init__()

    def evaluate(self):
        result = np.zeros(shape=(self.time_frame.days, self.scope.nodes))
        if self.initial is None:
            # Default behavior if initial is None (e.g., initialize with 1.0)
            result[0, :] = np.ones(self.scope.nodes)
        elif isinstance(self.initial, UniformPrior):
            # Sample from the UniformPrior if initial is a UniformPrior object
            result[0, :] = self.initial.sample(self.scope.nodes, self.rng)
        else:
            # Otherwise, use the initial value directly (e.g., fixed number like 100.0)
            result[0, :] = np.full(self.scope.nodes, self.initial)
        for i in range(self.time_frame.days - 1):
            result[i + 1, :] = np.exp(
                self.rng.normal(np.log(result[i, :]), self.voliatility)
            )
        return result

In [167]:
from epymorph.adrio import acs5
from epymorph import *  # noqa: F403
from epymorph.geography.us_census import StateScope
from epymorph.data.ipm.sirh import SIRH
from epymorph.data.mm.no import No
from epymorph.initializer import Proportional
from epymorph.rume import SingleStrataRUME
from epymorph.time import TimeFrame


In [168]:
from epymorph.params import ResultDType
from numpy.typing import NDArray


class ParamLoader(ParamFunctionNumpy):
    def __init__(self, realizations: NDArray[ResultDType]):
        self.realizations = realizations

    def evaluate(self):
        return self.realizations[0, ...]

    def get_realization(self, index: int = 0):
        return self.realizations[index, ...]

In [None]:
from types import SimpleNamespace
from typing import Callable
from epymorph import initializer
from epymorph.attribute import NamePattern
from epymorph.simulation import ParamValue
from epymorph.simulator.basic.basic_simulator import RUMEType
from epymorph.util import CovariantMapping
import dataclasses


class ForecastSimulator:
    """The RUME we will use for the simulation."""

    @staticmethod
    def run(
        rume: RUMEType,
        num_realizations: int,
        params: CovariantMapping[str | NamePattern, ParamValue] | None = None,
        rng_factory: Callable[[], np.random.Generator] | None = None,
    ):
        days = rume.time_frame.days
        taus = rume.num_tau_steps
        R = num_realizations
        S = days * taus
        N = rume.scope.nodes
        C = rume.ipm.num_compartments
        E = rume.ipm.num_events

        initial = np.empty(shape=(R, N, C), dtype=np.int64)
        compartments = np.empty(shape=(R, S, N, C), dtype=np.int64)
        events = np.empty(shape=(R, S, N, E), dtype=np.int64)

        rng = (rng_factory or np.random.default_rng)()

        dummy_data = rume.evaluate_params(override_params=params, rng=rng)

        return_params = {}
        for key in dummy_data.to_dict().keys():
            # Store the results with the shape of (number of realizations, shape of the raw data)
            return_params[str(key)] = np.empty(
                shape=(R,) + dummy_data.get_raw(key).shape,
                dtype=dummy_data.get_raw(key).dtype,
            )

        if params is None:
            params = rume.params
        for i in range(R):
            new_params = {}
            for name_pattern, param_value in params.items():
                if isinstance(param_value, ParamLoader):
                    realization = param_value.get_realization(i)
                    new_params[name_pattern] = realization
                    print("realization = ", realization)
            data = rume.evaluate_params(override_params=new_params, rng=rng)
            evaluated_params = dict(
                zip(
                    [str(key) for key in data.to_dict().keys()],
                    [data.get_raw(key) for key in data.to_dict().keys()],
                )
            )

            sim = BasicSimulator(rume)

            out = sim.run(params=evaluated_params, rng_factory=(lambda: rng))

            initial[i, ...] = out.initial
            compartments[i, ...] = out.compartments
            events[i, ...] = out.events

            for key in evaluated_params.keys():
                return_params[key][i, ...] = evaluated_params[key]

        return SimpleNamespace(
            rume=rume,
            initial=initial,
            compartments=compartments,
            events=events,
            params=return_params,
        )

    @staticmethod
    def extend(
        output,
        duration: int,
        params: CovariantMapping[str | NamePattern, ParamValue] | None = None,
        rng_factory: Callable[[], np.random.Generator] | None = None,
    ):
        forecast_start_date = output.rume.time_frame.end_date.strftime("%Y-%m-%d")

        for key, value in output.params.items():
            if len(value.shape) > 1 and value.shape[1] >= output.rume.time_frame.days:
                output.params[key] = value[:, output.rume.time_frame.days :]
        # print(output.params)
        # print(output.params["gpm:all::ipm::" + 'beta'].shape[1])

        new_params = {}
        if params is None:
            params_override = False
            params = output.rume.params
            param_names = []
            for name_pattern, param_value in params.items():  # type: ignore
                if isinstance(param_value, ParamFunctionNumpy):
                    param_names.append(name_pattern.id)  # type: ignore
                    new_params[name_pattern.id] = np.zeros(
                        output.params["gpm:all::ipm::" + name_pattern.id].shape[1]
                    )  # type: ignore
            # params = new_params
        else:
            params_override = True
            param_names = list(params.keys())

        rume = dataclasses.replace(
            output.rume, time_frame=TimeFrame.of(forecast_start_date, duration)
        )

        days = duration
        taus = rume.num_tau_steps
        R = output.compartments.shape[0]
        S = days * taus
        N = rume.scope.nodes
        C = rume.ipm.num_compartments
        E = rume.ipm.num_events

        initial = np.empty(shape=(R, N, C), dtype=np.int64)
        compartments = np.empty(shape=(R, S, N, C), dtype=np.int64)
        events = np.empty(shape=(R, S, N, E), dtype=np.int64)

        rng = (rng_factory or np.random.default_rng)()

        dummy_data = rume.evaluate_params(override_params=new_params, rng=rng)

        return_params = {}
        for key in dummy_data.to_dict().keys():
            # Store the results with the shape of (number of realizations, shape of the raw data)
            return_params[str(key)] = np.empty(
                shape=(R,) + dummy_data.get_raw(key).shape,
                dtype=dummy_data.get_raw(key).dtype,
            )

        # print("return params = ", return_params)
        # print('shape = ',return_params["gpm:all::ipm::beta"].shape)
        # print('params = ', params)

        for i in range(R):
            new_parameters = {}
            for param in param_names:
                param_array = output.params["gpm:all::ipm::" + param]
                indexing_tuple = (
                    (i, -1, slice(None)) if param_array.ndim == 3 else (i, slice(None))
                )
                key = NamePattern(strata="*", module="*", id=param)  # type: ignore
                if isinstance(params[key], ParamLoader):  # type: ignore
                    # print('enter')
                    if params_override:
                        # If the value is a ParamLoader instance, evaluate or get a realization
                        realization = params[param].get_realization(i)  # type: ignore
                        new_parameters[param] = realization  # type: ignore
                    else:
                        if (
                            param_array.shape[1]  # type: ignore
                            < duration
                        ):
                            raise ValueError(
                                "Missing data for the parameters..check duration and "
                                "parameter space being estimated"
                            )
                        else:
                            new_parameters[param] = param_array[i]
                            # print("params aray = ", param_array[i])
                elif isinstance(params[param], ParamFunctionNumpy):  # type: ignore
                    params[param].initial = param_array[indexing_tuple]  # type: ignore
                    new_parameters[param] = params[param]  # type: ignore
                else:
                    new_parameters[name_pattern] = param_value  # type: ignore

            rume_propagate = dataclasses.replace(
                rume,
                strata=[
                    dataclasses.replace(
                        g,
                        init=initializer.Explicit(initials=output.compartments[i][-1]),
                    )  # Initialize with state values
                    for g in output.rume.strata  # For each stratum, set the initial state
                ],
            )

            data = rume_propagate.evaluate_params(
                override_params=new_parameters, rng=rng
            )
            evaluated_params = dict(
                zip(
                    [str(key) for key in data.to_dict().keys()],
                    [data.get_raw(key) for key in data.to_dict().keys()],
                )
            )

            # print('evaluated = ', evaluated_params)
            # print('shape = ',evaluated_params["gpm:all::ipm::beta"].shape)

            sim = BasicSimulator(rume_propagate)

            # Run the simulation and collect the output based on observations
            # (dynamic params)
            out = sim.run(evaluated_params, rng_factory=(lambda: rng))

            initial[i, ...] = out.initial
            compartments[i, ...] = out.compartments
            events[i, ...] = out.events

            for key in evaluated_params.keys():
                return_params[key][i, ...] = evaluated_params[key]

        return SimpleNamespace(
            rume=rume,
            initial=initial,
            compartments=compartments,
            events=events,
            params=return_params,
        )


In [170]:
rng = np.random.default_rng()
rume = SingleStrataRUME.build(
    ipm=SIRH(),
    mm=No(),
    scope=StateScope.in_states(["AZ"], year=2015),
    init=Proportional(ratios=np.array([9999, 1, 0, 0], dtype=np.int64)),
    time_frame=TimeFrame.of("2022-10-01", 5),
    params={
        # "beta": GBM(initial=UniformPrior(lower=0.4, upper=0.41), voliatility=0.01),
        "beta": ParamLoader(rng.random((7, 10))),
        "gamma": 0.25,
        "xi": 1 / 365,  # 0.0111,
        "hospitalization_prob": 0.01,
        "hospitalization_duration": 5.0,
        "population": acs5.Population(),
    },
)

In [171]:
forecast_output = ForecastSimulator.run(rume=rume, num_realizations=7)

realization =  [0.45045798 0.1630732  0.07558923 0.6657657  0.5684364  0.05543542
 0.07792022 0.05687567 0.74831525 0.47860779]
realization =  [0.54201223 0.58452107 0.98496233 0.28819757 0.19147916 0.18434658
 0.46584541 0.61338148 0.02928364 0.18485286]
realization =  [0.99940042 0.55641859 0.26577084 0.96927116 0.41583815 0.52896151
 0.3499858  0.87790774 0.51377728 0.60192181]
realization =  [0.24445098 0.07500481 0.62652451 0.79670622 0.14903919 0.70336311
 0.8030516  0.84179637 0.31018354 0.95708124]
realization =  [0.51067941 0.09395831 0.0301252  0.7224867  0.35327166 0.36054446
 0.41820937 0.45025591 0.99936254 0.54245537]
realization =  [0.70380973 0.36718295 0.63553061 0.29370756 0.64326146 0.09847948
 0.8458998  0.18115939 0.88257431 0.63322203]
realization =  [0.24033951 0.42331997 0.97920779 0.56731855 0.5945175  0.60019737
 0.06739907 0.97381445 0.3198935  0.8834911 ]


In [172]:
forecast_output.rume.time_frame.days

5

In [173]:
forecast_output.compartments.shape

(7, 5, 1, 4)

In [174]:
forecast_output.params

{'gpm:all::ipm::beta': array([[0.45045798, 0.1630732 , 0.07558923, 0.6657657 , 0.5684364 ,
         0.05543542, 0.07792022, 0.05687567, 0.74831525, 0.47860779],
        [0.54201223, 0.58452107, 0.98496233, 0.28819757, 0.19147916,
         0.18434658, 0.46584541, 0.61338148, 0.02928364, 0.18485286],
        [0.99940042, 0.55641859, 0.26577084, 0.96927116, 0.41583815,
         0.52896151, 0.3499858 , 0.87790774, 0.51377728, 0.60192181],
        [0.24445098, 0.07500481, 0.62652451, 0.79670622, 0.14903919,
         0.70336311, 0.8030516 , 0.84179637, 0.31018354, 0.95708124],
        [0.51067941, 0.09395831, 0.0301252 , 0.7224867 , 0.35327166,
         0.36054446, 0.41820937, 0.45025591, 0.99936254, 0.54245537],
        [0.70380973, 0.36718295, 0.63553061, 0.29370756, 0.64326146,
         0.09847948, 0.8458998 , 0.18115939, 0.88257431, 0.63322203],
        [0.24033951, 0.42331997, 0.97920779, 0.56731855, 0.5945175 ,
         0.60019737, 0.06739907, 0.97381445, 0.3198935 , 0.8834911 ]]),
 'g

In [175]:
forecast_extend = ForecastSimulator.extend(
    forecast_output,
    duration=5,  # , params={"beta": GBM(initial=None, voliatility=0.01)}
)

In [176]:
forecast_extend.compartments.shape

(7, 5, 1, 4)

In [177]:
forecast_extend.params

{'gpm:all::ipm::beta': array([[0.05543542, 0.07792022, 0.05687567, 0.74831525, 0.47860779],
        [0.18434658, 0.46584541, 0.61338148, 0.02928364, 0.18485286],
        [0.52896151, 0.3499858 , 0.87790774, 0.51377728, 0.60192181],
        [0.70336311, 0.8030516 , 0.84179637, 0.31018354, 0.95708124],
        [0.36054446, 0.41820937, 0.45025591, 0.99936254, 0.54245537],
        [0.09847948, 0.8458998 , 0.18115939, 0.88257431, 0.63322203],
        [0.60019737, 0.06739907, 0.97381445, 0.3198935 , 0.8834911 ]]),
 'gpm:all::ipm::gamma': array([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]),
 'gpm:all::ipm::xi': array([0.00273973, 0.00273973, 0.00273973, 0.00273973, 0.00273973,
        0.00273973, 0.00273973]),
 'gpm:all::ipm::hospitalization_prob': array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01]),
 'gpm:all::ipm::hospitalization_duration': array([5., 5., 5., 5., 5., 5., 5.]),
 'gpm:all::init::population': array([[25895968444448860],
        [23925768161198147],
        [32370111954616435],
   

In [32]:
forecast_output.params

{'gpm:all::ipm::beta': array([[0.96552842, 0.13095013, 0.44895435, 0.05187637, 0.92524406],
        [0.58430256, 0.8251559 , 0.96852491, 0.45787162, 0.28761924],
        [0.38929357, 0.40932384, 0.37978491, 0.68732463, 0.05499807],
        [0.57324701, 0.98129156, 0.56744396, 0.08360449, 0.45954573],
        [0.32684272, 0.69944276, 0.14999393, 0.52671305, 0.34052873],
        [0.09256068, 0.75083752, 0.36588713, 0.37690154, 0.2087942 ],
        [0.07830595, 0.26561788, 0.75764892, 0.27503524, 0.86674014],
        [0.13954575, 0.96963345, 0.07376746, 0.98094978, 0.75382525],
        [0.41940521, 0.91047526, 0.95029018, 0.98575689, 0.65672581],
        [0.37680579, 0.10087524, 0.66243995, 0.05149404, 0.22546074]]),
 'gpm:all::ipm::gamma': array([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25]),
 'gpm:all::ipm::xi': array([0.00273973, 0.00273973, 0.00273973, 0.00273973, 0.00273973,
        0.00273973, 0.00273973, 0.00273973, 0.00273973, 0.00273973]),
 'gpm:all::ipm::hospitali

In [35]:
for param in param_names:
    param_array = forecast_output.params["gpm:all::ipm::" + param]
    print(param_array.shape[1])

5


In [74]:
param_names = [key.split("::")[-1] for key in forecast_output.params.keys()]
param_names

['beta',
 'gamma',
 'xi',
 'hospitalization_prob',
 'hospitalization_duration',
 'population',
 'label']

In [101]:
forecast_output.params["gpm:all::ipm::beta"][0]

array([[0.40445935],
       [0.40286018],
       [0.39853281],
       [0.39466837],
       [0.39114458],
       [0.38441981],
       [0.38343928],
       [0.38083885],
       [0.38117172],
       [0.37722004]])

In [None]:
len(forecast_output.compartments)  # .shape[0]

100

In [None]:
rume = SingleStrataRUME.build(
    ipm=SIRH(),
    mm=No(),
    scope=StateScope.in_states(["AZ"], year=2015),
    init=Proportional(ratios=np.array([9999, 1, 0, 0], dtype=np.int64)),
    time_frame=TimeFrame.of("2022-10-01", 10),
    params={
        "beta": GBM(initial=None, voliatility=0.01),
        "gamma": 0.25,
        "xi": 1 / 365,  # 0.0111,
        "hospitalization_prob": 0.01,
        "hospitalization_duration": 5.0,
        "population": acs5.Population(),
    },
)

In [None]:
extracted_params = {}

# Loop over rume.params and extract key-value pairs
for key, value in rume.params.items():
    # You can use the 'id' as the key for simplicity
    extracted_params[key.id] = value

extracted_params

{'beta': <__main__.GBM at 0x20ab8c17710>,
 'gamma': 0.25,
 'xi': 0.0027397260273972603,
 'hospitalization_prob': 0.01,
 'hospitalization_duration': 5.0,
 'population': <epymorph.adrio.acs5.Population at 0x20ab8f70c90>}

In [109]:
rume.params
data = rume.evaluate_params(rng=np.random.default_rng(seed=1))
[data.get_raw(key) for key in data.to_dict().keys()]

[array([[0.40511822],
        [0.40846045],
        [0.40981239],
        [0.40450654],
        [0.40818539],
        [0.4100115 ],
        [0.40781583],
        [0.41019262],
        [0.4116908 ],
        [0.4129035 ]]),
 array(0.25),
 array(0.00273973),
 array(0.01),
 array(5.),
 array([6641928], dtype=int64),
 array(['AZ'], dtype='<U2')]

In [None]:
params = {"beta": GBM(initial=None, voliatility=0.01)}


def run(output, params):
    param_names = list(params.keys())
    for i in range(1):
        parameters = {}
        for param in param_names:
            param_array = output.params["gpm:all::ipm::" + param]
            indexing_tuple = (
                (i, -1, slice(None)) if param_array.ndim == 3 else (i, slice(None))
            )
            if not isinstance(params[param], (int, float, np.ndarray)):
                params[param].initial = param_array[indexing_tuple]
                parameters[param] = params[param]
            else:
                parameters[param] = params[param]

            data = rume.evaluate_params(
                override_params=parameters, rng=np.random.default_rng(seed=1)
            )
            evaluated_params = dict(
                zip(
                    [str(key) for key in data.to_dict().keys()],
                    [data.get_raw(key) for key in data.to_dict().keys()],
                )
            )

            print(evaluated_params)

In [None]:
run(forecast_output, params=None, rume=rume)

{'gpm:all::ipm::beta': array([[0.37722004],
       [0.37852591],
       [0.38164876],
       [0.38291195],
       [0.37795438],
       [0.38139175],
       [0.38309799],
       [0.38104644],
       [0.38326722],
       [0.38466706]]), 'gpm:all::ipm::gamma': array(0.25), 'gpm:all::ipm::xi': array(0.00273973), 'gpm:all::ipm::hospitalization_prob': array(0.01), 'gpm:all::ipm::hospitalization_duration': array(5.), 'gpm:all::init::population': array([6641928], dtype=int64), 'meta::geo::label': array(['AZ'], dtype='<U2')}


In [None]:
def run(output, params):
    param_names = list(params.keys())
    rume = output.rume
    evaluated_params = {}
    out = None
    forecast_start_date = rume.time_frame.end_date.strftime("%Y-%m-%d")
    for i in range(1):
        parameters = {}
        for param in param_names:
            param_array = output.params["gpm:all::ipm::" + param]
            indexing_tuple = (
                (i, -1, slice(None)) if param_array.ndim == 3 else (i, slice(None))
            )
            if not isinstance(params[param], (int, float, np.ndarray)):
                params[param].initial = param_array[indexing_tuple]
                parameters[param] = params[param]
            else:
                parameters[param] = params[param]

        rume_propagate = dataclasses.replace(
            rume,
            time_frame=TimeFrame.of(forecast_start_date, 10),
            strata=[
                dataclasses.replace(
                    g,
                    init=initializer.Explicit(initials=output.compartments[i][-1]),
                )  # Initialize with state values
                for g in output.rume.strata  # For each stratum, set the initial state
            ],
        )

        data = rume_propagate.evaluate_params(
            override_params=parameters, rng=np.random.default_rng(seed=1)
        )
        evaluated_params = dict(
            zip(
                [str(key) for key in data.to_dict().keys()],
                [data.get_raw(key) for key in data.to_dict().keys()],
            )
        )

        evaluated_params = evaluated_params
        print(evaluated_params)

        sim = BasicSimulator(rume_propagate)

        out = sim.run(evaluated_params)

    return evaluated_params, out


In [70]:
forecast_output.compartments[0][-1]

array([[6639471,    1358,    1094,       5]], dtype=int64)

In [71]:
evaluated_params, out = run(
    forecast_output, params={"beta": GBM(initial=None, voliatility=0.01)}
)

{'gpm:all::ipm::beta': array([[0.4105499 ],
       [0.41197115],
       [0.41536992],
       [0.41674473],
       [0.41134912],
       [0.4150902 ],
       [0.4169472 ],
       [0.41471439],
       [0.41713139],
       [0.41865491]]), 'gpm:all::ipm::gamma': array(0.25), 'gpm:all::ipm::xi': array(0.00273973), 'gpm:all::ipm::hospitalization_prob': array(0.01), 'gpm:all::ipm::hospitalization_duration': array(5.), 'meta::geo::label': array(['AZ'], dtype='<U2')}


In [None]:
sim = BasicSimulator(rume)

# Run the simulation and collect the output based on observations
# (dynamic params)
out = sim.run(evaluated_params)

In [87]:
not isinstance(params["beta"], (int, float, str, bool))

True

In [None]:
param_types = {}

for key, value in params.items():
    if callable(value):
        param_types[key] = "function"
    elif isinstance(value, object) and not isinstance(
        value, type
    ):  # Check if it's an instance of any class
        param_types[key] = "class instance"
    elif isinstance(
        value, type
    ):  # This means it's a class type itself, not an instance
        param_types[key] = "class"
    elif isinstance(
        value, (int, float, str, bool)
    ):  # Check if it's a static value (number, string, etc.)
        param_types[key] = "static value"
    else:
        param_types[key] = "unknown"

In [71]:
param_types

{'beta': 'class instance', 'gama': 'class instance'}