In [None]:
import numpy as np
from epymorph.kit import *

from epymorph.adrio import acs5
from epymorph import *  # noqa: F403
from epymorph.geography.us_census import StateScope
from epymorph.data.ipm.sirh import SIRH
from epymorph.data.mm.no import No
from epymorph.initializer import Proportional
from epymorph.rume import SingleStrataRUME
from epymorph.time import TimeFrame

In [None]:
from epymorph.params import ParamFunctionNumpy
import scipy as sp


class UniformPrior:
    """Uniform prior distribution between ``lower`` and ``upper``.

    Sampling is delegated to ``scipy.stats.uniform`` with ``loc=lower`` and
    ``scale=upper - lower``.
    """

    def __init__(self, lower, upper):
        self.lower, self.upper = lower, upper

    def sample(self, size, rng):
        """Draw ``size`` values from the prior using the random source ``rng``."""
        width = self.upper - self.lower
        return sp.stats.uniform.rvs(
            size=size,
            loc=self.lower,
            scale=width,
            random_state=rng,
        )


class GBM(ParamFunctionNumpy):
    """Per-node parameter that follows a random walk in log space.

    Day 0 holds the initial level. Every following day multiplies the previous
    day's value by ``exp(z)`` with ``z ~ Normal(0, voliatility)``, drawn
    independently per node from ``self.rng``. This gives a discrete, driftless
    geometric-Brownian-motion-style path.

    Args:
        initial: Day-0 level. ``None`` gives 1.0 for every node; a
            ``UniformPrior`` is sampled once per node; any other value is
            broadcast to all nodes with ``np.full``.
        voliatility: Standard deviation of each daily log-increment. (The
            misspelling is part of the public keyword and is kept for
            compatibility with existing callers.)
    """

    def __init__(self, initial=None, voliatility=0.2):
        self.initial = initial
        self.voliatility = voliatility
        super().__init__()

    def evaluate(self):
        """Return a ``(time_frame.days, scope.nodes)`` array of values."""
        days = self.time_frame.days
        nodes = self.scope.nodes
        result = np.zeros(shape=(days, nodes))
        if days == 0:
            # Nothing to simulate; avoid indexing a non-existent day 0.
            return result

        if self.initial is None:
            # Default level when no initial value is given.
            result[0, :] = np.ones(nodes)
        elif isinstance(self.initial, UniformPrior):
            # One independent draw from the prior per node.
            result[0, :] = self.initial.sample(nodes, self.rng)
        else:
            # Fixed scalar (or per-node array) broadcast to every node.
            result[0, :] = np.full(nodes, self.initial)

        if days > 1:
            # Draw every daily log-increment in one call. The draws are taken
            # row-major, i.e. in the same order as a per-day loop would take
            # them. They are then accumulated in log space, so
            # value[t] = value[0] * exp(z_1 + ... + z_t).
            increments = self.rng.normal(
                0.0, self.voliatility, size=(days - 1, nodes)
            )
            result[1:, :] = np.exp(
                np.log(result[0, :]) + np.cumsum(increments, axis=0)
            )
        return result

In [None]:
class ForecastSimulator:
    """Placeholder for a forecast simulator built around a RUME.

    For now the constructor only records its configuration; no simulation
    logic is implemented yet.

    Args:
        rume: Model to forecast with (presumably an epymorph RUME such as the
            one built with ``SingleStrataRUME.build`` -- confirm).
        run_particlefilter: Flag presumably selecting whether a particle
            filter runs before forecasting (not implemented yet).
        forecast_time_frame: Time frame to forecast over.
    """

    def __init__(self, rume, run_particlefilter: bool, forecast_time_frame):
        # Keep the configuration so future methods can use it.
        self.rume = rume
        self.run_particlefilter = run_particlefilter
        self.forecast_time_frame = forecast_time_frame

In [None]:
# Simulation length in days (14 weeks).
duration = 7 * 14
t = np.arange(0, duration)
# Reference seasonal beta: annual cosine with amplitude 0.03 around 0.28, so
# values stay within [0.25, 0.31]. Not passed to the RUME below (it may be
# used by later cells).
# NOTE(review): the GBM prior for beta starts in [0.40, 0.41], well above this
# curve -- confirm that is intended.
true_beta = 0.03 * np.cos(t * 2 * np.pi / (365)) + 0.28

rume = SingleStrataRUME.build(
    ipm=SIRH(),
    mm=No(),
    scope=StateScope.in_states(["AZ"], year=2015),
    # Initial compartment ratios; presumably ordered S, I, R, H -- confirm
    # against SIRH's compartment order.
    init=Proportional(ratios=np.array([9999, 1, 0, 0], dtype=np.int64)),
    # Reuse `duration` so the time frame and `true_beta` have the same length.
    time_frame=TimeFrame.of("2022-10-01", duration),
    params={
        # Time-varying transmission rate: log-space random walk starting from
        # a per-node draw in [0.40, 0.41].
        "beta": GBM(initial=UniformPrior(lower=0.4, upper=0.41), voliatility=0.01),
        "gamma": 0.25,
        "xi": 1 / 365,  # alternative value tried: 0.0111 (about 1/90)
        "hospitalization_prob": 0.01,
        "hospitalization_duration": 5.0,
        "population": acs5.Population(),
    },
)