## Short Demonstration Of `pyrenew-flu-light`, 2024-08-27

_This post proceeds through the following._

1. Background (`cfaepim`) 
   * Show `cfa-forecast-renewal-epidemia` & `cfa-flu-eval` 
   * Show website, `cfaepim` forecasts, Epidemia, model mathematics
2. The current re-implementation, i.e., `pyrenew-flu-light`.
3. Reflections on using PyRenew for development.
4. Feedback, which is welcome and much needed.

The __goal__ was to use the re-implementation of `cfaepim` to find issues with PyRenew.

(see Quarto site)

### Current Instantiation

* Historical mode only; requires clean data.
* Eventually, we hope to connect to an API (NHSN or NSSP).
* A KS test and further quantitative evaluation are in progress.

How To Run: _At present, the model is run via command-line arguments._

```
python3 tut_epim_port_msr.py --reporting_date 2024-03-30 --regions NY --historical

python3 tut_epim_port_msr.py --reporting_date 2024-03-30 --regions TX --historical --forecast 
```
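The script's argument parsing is not shown in this post. As a rough sketch only, the flags used above could be declared with the standard-library `argparse` module as follows; `build_parser`, the `nargs="+"` choice for `--regions`, and the help strings are assumptions for illustration, not the script's actual code.

```python
import argparse
from datetime import date


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags in the commands above."""
    parser = argparse.ArgumentParser(description="Run the CFAEPIM re-implementation.")
    parser.add_argument(
        "--reporting_date",
        type=date.fromisoformat,  # "2024-03-30" -> datetime.date(2024, 3, 30)
        required=True,
        help="Reporting date, YYYY-MM-DD.",
    )
    parser.add_argument(
        "--regions",
        nargs="+",  # assumption: one or more state abbreviations
        required=True,
        help="Region(s) to fit, e.g. NY or TX.",
    )
    parser.add_argument(
        "--historical", action="store_true", help="Run in historical mode."
    )
    parser.add_argument(
        "--forecast", action="store_true", help="Also produce forecasts."
    )
    return parser


args = build_parser().parse_args(
    ["--reporting_date", "2024-03-30", "--regions", "TX", "--historical", "--forecast"]
)
print(args.reporting_date, args.regions, args.historical, args.forecast)
# 2024-03-30 ['TX'] True True
```

With `action="store_true"`, omitting `--forecast` leaves it `False`, which would correspond to the first command above.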

```python
import subprocess
import matplotlib.pyplot as plt

# python3 tut_epim_port_msr.py --reporting_date 2024-03-30 --regions NY --historical --forecast
subprocess.run(
    [
        "python3",
        "tut_epim_port_msr.py",
        "--reporting_date",
        "2024-03-30",
        "--regions",
        "NY",
        "--historical",
        "--forecast",
    ]
)

plt.show()
```





```
2024-08-27 10:39:53,981 - INFO - Starting CFAEPIM
2024-08-27 10:39:53,981 - INFO - Number of cores set.
2024-08-27 10:39:53,981 - INFO - Output directory ensured working.
2024-08-27 10:39:53,982 - INFO - Configuration (historical) loaded.
2024-08-27 10:39:53,999 - INFO - Incidence data (historical) loaded.
2024-08-27 10:39:54,002 - INFO - NY: Dataset w/ pre-observation ready.
2024-08-27 10:39:54,003 - INFO - NY: Dataset w/ post-observation ready.
2024-08-27 10:39:54,004 - INFO - NY: Variables extracted from dataset.
2024-08-27 10:39:54,014 - INFO - Unable to initialize backend 'cuda': 
2024-08-27 10:39:54,014 - INFO - Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
2024-08-27 10:39:54,015 - INFO - Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such

./output/Historical_2024-03-30/

2024-08-27 10:39:54,188 - INFO - Initializing CFAEPIM_Observation
2024-08-27 10:39:54,188 - INFO - Validating CFAEPIM_Observation parameters
2024-08-27 10:39:54,188 - INFO - Initializing alpha process
2024-08-27 10:39:54,188 - INFO - Initializing negative binomial process
2024-08-27 10:39:54,188 - INFO - NY: CFAEPIM model instantiated (fitting)!
sample: 100%|██████████| 2000/2000 [01:12<00:00, 27.74it/s, 255 steps of size 1.53e-02. acc. prob=0.92]
2024-08-27 10:41:14,726 - INFO - NY: CFAEPIM model (fitting) ran!

                                mean       std    median      5.0%     95.0%     n_eff     r_hat
                     I0[0]   1979.81   1958.23   1408.41     13.71   4334.50   1509.87      1.00
                     I0[1]   2012.22   1984.99   1445.75      1.34   4680.58   1273.64      1.00
                     I0[2]   1981.92   1898.81   1418.56     12.30   4674.62   1306.60      1.00
                     I0[3]   1992.46   2056.71   1300.74      1.39   4868.99   1460.03      1.00
                     I0[4]   2040.80   1992.46   1423.00      1.78   4886.50   1679.60      1.00
                     I0[5]   2132.70   2046.15   1468.24      2.48   4639.83   1803.11      1.00
                     I0[6]   2314.07   2422.40   1507.71      3.18   5544.54   1509.68      1.00
                     I0[7]   2840.31   3070.00   1780.52      0.80   7057.53   1413.80      1.00
                     I0[8]   4251.38   3904.76   3099.36      1.00   9837.86   1001.89      1.00
                     I0[9]  1

2024-08-27 10:41:18,501 - INFO - NY: Prior predictive simulation complete.
2024-08-27 10:41:20,386 - INFO - NY: Posterior predictive simulation complete.
2024-08-27 10:41:20,387 - INFO - Initializing CFAEPIM_Rt
2024-08-27 10:41:20,388 - INFO - Initializing CFAEPIM_Infections
2024-08-27 10:41:20,388 - INFO - Initializing CFAEPIM_Observation
2024-08-27 10:41:20,388 - INFO - Validating CFAEPIM_Observation parameters
2024-08-27 10:41:20,388 - INFO - Initializing alpha process
2024-08-27 10:41:20,388 - INFO - Initializing negative binomial process
2024-08-27 10:41:22,318 - INFO - NY: Posterior predictive forecasts complete.
```




### Model Architecture

* Three components (`rt`, `observation`, `infection`)
* Unpooled across states (each state is fit separately).
* Uses `pyrenew`, `numpyro`, `jax`.

From `pyrenew`, what do we use?

```python
import pyrenew.transformation as t # for scaled logit transform

from pyrenew.deterministic import DeterministicPMF # for generation interval

from pyrenew.latent import (
    InfectionInitializationProcess, # seeding process
    InitializeInfectionsFromVec, # seeding process
    logistic_susceptibility_adjustment, # for transforming infections
)

from pyrenew.metaclass import (
    DistributionalRV, # for random walk, infection process, and observation process
    Model, # for the cfaepim model itself
    RandomVariable, # for subclassing the components
    SampledValue, # model sampling return value
    TransformedRandomVariable, # for scaling Rt samples
)

# for observation process
from pyrenew.observation import NegativeBinomialObservation

# for Rt random walk process
from pyrenew.process import SimpleRandomWalkProcess

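# for alpha process (instantaneous ascertainment rate, getting expected hosps.)
from pyrenew.regression import GLMPrediction

# non-PyRenew dependencies used by the components below
import logging

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.contrib.control_flow  # for scan in CFAEPIM_Infections
import numpyro.distributions as dist
from jax.typing import ArrayLike
from numpyro.infer.reparam import LocScaleReparam
```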

#### Rt Component

```python
class CFAEPIM_Rt(RandomVariable):  # numpydoc ignore=GL08
    def __init__(
        self,
        intercept_RW_prior: dist.Distribution,
        max_rt: float,
        gamma_RW_prior_scale: float,
        week_indices: ArrayLike,
    ):  
        logging.info("Initializing CFAEPIM_Rt")
        self.intercept_RW_prior = intercept_RW_prior
        self.max_rt = max_rt
        self.gamma_RW_prior_scale = gamma_RW_prior_scale
        self.week_indices = week_indices

    @staticmethod
    def validate(
        intercept_RW_prior: any,
        max_rt: any,
        gamma_RW_prior_scale: any,
        week_indices: any,
    ) -> None:  
        logging.info("Validating CFAEPIM_Rt parameters")
        if not isinstance(intercept_RW_prior, dist.Distribution):
            raise ValueError(
                f"intercept_RW_prior must be a numpyro distribution; was type {type(intercept_RW_prior)}"
            )
        if not isinstance(max_rt, (float, int)) or max_rt <= 0:
            raise ValueError(
                f"max_rt must be a positive number; was type {type(max_rt)}"
            )
        if (
            not isinstance(gamma_RW_prior_scale, (float, int))
            or gamma_RW_prior_scale <= 0
        ):
            raise ValueError(
                f"gamma_RW_prior_scale must be a positive number; was type {type(gamma_RW_prior_scale)}"
            )
        if not isinstance(week_indices, (np.ndarray, jnp.ndarray)):
            raise ValueError(
                f"week_indices must be an array-like structure; was type {type(week_indices)}"
            )

    def sample(self, n_steps: int, **kwargs) -> ArrayLike:  # numpydoc ignore=GL08
        # sample the standard deviation for the random walk process
        sd_wt = numpyro.sample(
            "Wt_rw_sd", dist.HalfNormal(self.gamma_RW_prior_scale)
        )
        # Rt random walk process
        wt_rv = SimpleRandomWalkProcess(
            name="Wt",
            step_rv=DistributionalRV(
                name="rw_step_rv",
                dist=dist.Normal(0, sd_wt),
                reparam=LocScaleReparam(0),
            ),
            init_rv=DistributionalRV(
                name="init_Wt_rv",
                dist=self.intercept_RW_prior,
            ),
        )
        # transform Rt random walk w/ scaled logit
        transformed_rt_samples = TransformedRandomVariable(
            name="transformed_rt_rw",
            base_rv=wt_rv,
            transforms=t.ScaledLogitTransform(x_max=self.max_rt).inv,
        ).sample(n_steps=n_steps, **kwargs)
        # broadcast the Rt samples to daily values
        broadcasted_rt_samples = transformed_rt_samples[0].value[
            self.week_indices
        ]
        logging.debug(f"Broadcasted Rt samples: {broadcasted_rt_samples}")
        return broadcasted_rt_samples
```
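To make the Rt component's data flow concrete, here is a NumPy-only forward simulation (no inference, no priors) of the same three steps: a weekly Gaussian random walk on the unconstrained scale, the inverse scaled-logit map into (0, `max_rt`), which this sketch assumes is `max_rt * sigmoid(w)`, and broadcasting to daily values by indexing with `week_indices`. The helper names, the fixed step scale, and the starting value are illustrative assumptions rather than `cfaepim` settings.

```python
import numpy as np


def inverse_scaled_logit(w: np.ndarray, x_max: float) -> np.ndarray:
    """Map unconstrained values into (0, x_max) as x_max * sigmoid(w)."""
    return x_max / (1.0 + np.exp(-w))


def simulate_daily_rt(
    n_weeks: int,
    week_indices: np.ndarray,
    max_rt: float,
    rw_sd: float,
    init_w: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Hypothetical forward simulation of daily Rt from a weekly random walk."""
    # weekly Gaussian random walk on the unconstrained scale, starting at init_w
    steps = rng.normal(0.0, rw_sd, size=n_weeks - 1)
    w = init_w + np.concatenate([[0.0], np.cumsum(steps)])
    # squash into (0, max_rt)
    rt_weekly = inverse_scaled_logit(w, max_rt)
    # broadcast to days: day d takes the value of week week_indices[d]
    return rt_weekly[week_indices]


rng = np.random.default_rng(0)
week_indices = np.repeat(np.arange(4), 7)  # 28 days mapped to weeks 0..3
rt_daily = simulate_daily_rt(
    n_weeks=4, week_indices=week_indices, max_rt=3.0, rw_sd=0.1, init_w=0.0, rng=rng
)
print(rt_daily.shape)  # (28,)
```

Because `week_indices` is an ordinary integer index array, the broadcast step needs no special casing for partial weeks: each day simply looks up the weekly value it belongs to.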

#### Observation Component


```python
class CFAEPIM_Observation(RandomVariable):
    def __init__(
        self,
        predictors,
        alpha_prior_dist,
        coefficient_priors,
        nb_concentration_prior,
    ):  # numpydoc ignore=GL08
        logging.info("Initializing CFAEPIM_Observation")

        CFAEPIM_Observation.validate(
            predictors,
            alpha_prior_dist,
            coefficient_priors,
            nb_concentration_prior,
        )

        self.predictors = predictors
        self.alpha_prior_dist = alpha_prior_dist
        self.coefficient_priors = coefficient_priors
        self.nb_concentration_prior = nb_concentration_prior

        self._init_alpha_t()
        self._init_negative_binomial()

    def _init_alpha_t(self):
        logging.info("Initializing alpha process")
        self.alpha_process = GLMPrediction(
            name="alpha_t",
            fixed_predictor_values=self.predictors,
            intercept_prior=self.alpha_prior_dist,
            coefficient_priors=self.coefficient_priors,
            transform=t.SigmoidTransform().inv,
        )

    def _init_negative_binomial(self):
        logging.info("Initializing negative binomial process")
        self.nb_observation = NegativeBinomialObservation(
            name="negbinom_rv",
            concentration_rv=DistributionalRV(
                name="nb_concentration",
                dist=self.nb_concentration_prior,
            ),
        )

    @staticmethod
    def validate(
        predictors: any,
        alpha_prior_dist: any,
        coefficient_priors: any,
        nb_concentration_prior: any,
    ) -> None:
        logging.info("Validating CFAEPIM_Observation parameters")
        if not isinstance(predictors, (np.ndarray, jnp.ndarray)):
            raise TypeError(
                f"Predictors must be an array-like structure; was type {type(predictors)}"
            )
        if not isinstance(alpha_prior_dist, dist.Distribution):
            raise TypeError(
                f"alpha_prior_dist must be a numpyro distribution; was type {type(alpha_prior_dist)}"
            )
        if not isinstance(coefficient_priors, dist.Distribution):
            raise TypeError(
                f"coefficient_priors must be a numpyro distribution; was type {type(coefficient_priors)}"
            )
        if not isinstance(nb_concentration_prior, dist.Distribution):
            raise TypeError(
                f"nb_concentration_prior must be a numpyro distribution; was type {type(nb_concentration_prior)}"
            )

    def sample(
        self,
        infections: ArrayLike,
        inf_to_hosp_dist: ArrayLike,
        **kwargs,
    ) -> tuple:
        alpha_samples = self.alpha_process.sample()["prediction"]
        alpha_samples = alpha_samples[: infections.shape[0]]
        expected_hosp = (
            alpha_samples
            * jnp.convolve(infections, inf_to_hosp_dist, mode="full")[
                : infections.shape[0]
            ]
        )
        logging.debug(f"Alpha samples: {alpha_samples}")
        logging.debug(f"Expected hospitalizations: {expected_hosp}")
        return alpha_samples, expected_hosp
```
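A deterministic NumPy sketch of the observation logic, assuming `GLMPrediction` with `transform=t.SigmoidTransform().inv` yields `alpha_t = sigmoid(intercept + predictors @ coefficients)`: expected hospitalizations are `alpha_t` times the infections convolved with the infection-to-hospitalization delay, truncated to the length of the infection series. The last helper shows one standard mean/concentration parameterization of the negative binomial for drawing counts; whether `NegativeBinomialObservation` uses exactly this parameterization should be checked against PyRenew. All function names and inputs here are hypothetical.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic function mapping the real line into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))


def expected_hospitalizations(
    infections: np.ndarray,
    inf_to_hosp: np.ndarray,
    predictors: np.ndarray,
    intercept: float,
    coefficients: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical, non-random analogue of CFAEPIM_Observation.sample."""
    n = infections.shape[0]
    # ascertainment rate from a linear predictor on the logit scale
    alpha = sigmoid(intercept + predictors @ coefficients)[:n]
    # spread infections over the delay PMF; keep only the first n days
    delayed = np.convolve(infections, inf_to_hosp, mode="full")[:n]
    return alpha, alpha * delayed


def sample_negative_binomial(
    mean: np.ndarray, concentration: float, rng: np.random.Generator
) -> np.ndarray:
    """Counts with E[y] = mean and Var[y] = mean + mean**2 / concentration."""
    # NumPy counts failures before `concentration` successes with success prob p
    p = concentration / (concentration + mean)
    return rng.negative_binomial(concentration, p)


rng = np.random.default_rng(1)
infections = np.full(10, 100.0)
delay = np.array([0.5, 0.3, 0.2])  # made-up delay PMF, sums to 1
predictors = np.zeros((10, 1))  # zero covariates, so alpha = sigmoid(intercept)
alpha, expected = expected_hospitalizations(
    infections, delay, predictors, intercept=np.log(0.01 / 0.99), coefficients=np.array([0.5])
)
hosp = sample_negative_binomial(expected, concentration=20.0, rng=rng)
```

The `[:n]` truncation mirrors the original component: the tail of the full convolution would describe hospitalizations beyond the last observed day, so it is dropped.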

#### Infection Component


```python
class CFAEPIM_Infections(RandomVariable):
    def __init__(
        self,
        I0: RandomVariable,  # sampled in .sample(), e.g. an InfectionInitializationProcess
        susceptibility_prior: dist.Distribution,
    ):  # numpydoc ignore=GL08
        logging.info("Initializing CFAEPIM_Infections")

        self.I0 = I0
        self.susceptibility_prior = susceptibility_prior

    @staticmethod
    def validate(I0: any, susceptibility_prior: any) -> None:
        logging.info("Validating CFAEPIM_Infections parameters")
        if not isinstance(I0, RandomVariable):
            raise TypeError(
                f"Initial infections (I0) must be a RandomVariable; was type {type(I0)}"
            )

        if not isinstance(susceptibility_prior, dist.Distribution):
            raise TypeError(
                f"susceptibility_prior must be a numpyro distribution; was type {type(susceptibility_prior)}"
            )

    def sample(
        self, Rt: ArrayLike, gen_int: ArrayLike, P: float, **kwargs
    ) -> tuple:
        # get initial infections
        I0_samples = self.I0.sample()
        I0 = I0_samples[0].value

        logging.debug(f"I0 samples: {I0}")

        # reverse generation interval (recency)
        gen_int_rev = jnp.flip(gen_int)

        if I0.size < gen_int.size:
            raise ValueError(
                "Initial infections vector must be at least as long as "
                "the generation interval. "
                f"Initial infections vector length: {I0.size}, "
                f"generation interval length: {gen_int.size}."
            )
        recent_I0 = I0[-gen_int_rev.size :]

        # sample the initial susceptible population proportion S_{v-1} / P from prior
        init_S_proportion = numpyro.sample(
            "S_v_minus_1_over_P", self.susceptibility_prior
        )
        logging.debug(f"Initial susceptible proportion: {init_S_proportion}")

        # calculate initial susceptible population S_{v-1}
        init_S = init_S_proportion * P

        def update_infections(carry, Rt):  # numpydoc ignore=GL08
            S_t, I_recent = carry

            # compute raw infections
            i_raw_t = Rt * jnp.dot(I_recent, gen_int_rev)

            # apply the logistic susceptibility adjustment to a potential new incidence
            i_t = logistic_susceptibility_adjustment(
                I_raw_t=i_raw_t, frac_susceptible=S_t / P, n_population=P
            )

            # update susceptible population
            S_t -= i_t

            # slide the window: drop the oldest infection, append the newest
            I_recent = jnp.concatenate([I_recent[1:], jnp.array([i_t])])

            return (S_t, I_recent), i_t

        # initial carry state
        init_carry = (init_S, recent_I0)

        # scan to iterate over time steps and update infections
        (final_S_t, _), all_I_t = numpyro.contrib.control_flow.scan(
            update_infections, init_carry, Rt
        )

        logging.debug(f"All infections: {all_I_t}")
        # the scan carry holds only the last susceptible count, not a trajectory
        logging.debug(f"Final susceptibles: {final_S_t}")

        return all_I_t, final_S_t
```
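A NumPy sketch of the renewal recursion in `CFAEPIM_Infections.sample`, written as a plain loop instead of `numpyro.contrib.control_flow.scan`. It assumes the logistic susceptibility adjustment has the form `P * (S_t / P) * (1 - exp(-I_raw / P))`; check `pyrenew.latent.logistic_susceptibility_adjustment` for the exact expression. The window of recent infections is kept oldest-first so that it pairs with the reversed generation interval, and each step drops the oldest value and appends the newest. Function names and inputs are hypothetical.

```python
import numpy as np


def susceptibility_adjust(i_raw: float, frac_susceptible: float, pop: float) -> float:
    """Assumed form: pop * frac_susceptible * (1 - exp(-i_raw / pop))."""
    # -expm1(-x) equals 1 - exp(-x) but stays accurate when x is tiny
    return pop * frac_susceptible * -np.expm1(-i_raw / pop)


def renewal_with_depletion(
    rt: np.ndarray, gen_int: np.ndarray, i0: np.ndarray, init_s: float, pop: float
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical forward simulation returning daily infections and susceptibles."""
    if i0.size < gen_int.size:
        raise ValueError("i0 must be at least as long as gen_int")
    gen_int_rev = gen_int[::-1]  # last entry (lag 1) pairs with the newest day
    window = i0[-gen_int.size :].astype(float)  # ordered oldest -> newest
    s_t = init_s
    infections, susceptibles = [], []
    for r in rt:
        i_raw = r * np.dot(window, gen_int_rev)
        i_t = susceptibility_adjust(i_raw, s_t / pop, pop)
        s_t -= i_t
        # slide the window: drop the oldest value, append the newest
        window = np.append(window[1:], i_t)
        infections.append(i_t)
        susceptibles.append(s_t)
    return np.array(infections), np.array(susceptibles)


inf, sus = renewal_with_depletion(
    rt=np.full(30, 1.5),
    gen_int=np.array([0.2, 0.5, 0.3]),
    i0=np.full(3, 10.0),
    init_s=9_000.0,
    pop=10_000.0,
)
```

In the JAX component the same loop body runs inside `scan`, which threads `(S_t, I_recent)` through as the carry; the Python loop here is only for readability.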

### Reflections

(Provided verbally; in short:)

* PyRenew changes frequently, but the changes were not crippling.
* The model metaclass and most of the "hard parts" (e.g., the random walk and the linear predictor) worked out of the box, and the smaller pieces all worked fine.
* Excited to see more collaboration with NumPyro developers.
* Excited to see future updates supporting pre- and post-processing around model instantiation, such as ETL and data visualization, possibly as part of PyRenew or `cfa-forecasttools`.
* Model evaluation is in progress.