# Getting started with `pyrenew`

This notebook illustrates two features of `pyrenew`: (a) the set of included `RandomProcess`es, and (b) model composition.

## Hospitalizations model

`pyrenew` has four main components:

- Utility and math functions
- The `processes` sub-module
- The `observations` sub-module
- The `models` sub-module

Both `processes` and `observations` contain classes that inherit from the metaclass `RandomProcess`, and the classes under `models` inherit from the metaclass `Model`.

```{mermaid}
flowchart TB

    subgraph randprocmod["Processes module"]
        direction TB
        simprw["SimpleRandomWalkProcess"]
        rtrw["RtRandomWalkProcess"]        
    end

    subgraph obsmod["Observations module"]
        direction TB
        pois["PoissonObservation"]
        hosp_obs["HospitalizationsObservation"]
    end

    subgraph models["Models module"]
        direction TB
        basic["BasicRenewalModel"]
        hosp["HospitalizationsModel"]
    end

    rp(("RandomProcess")) --> |Inherited by| simprw
    rp -->|Inherited by| rtrw
    rp -->|Inherited by| pois
    rp -->|Inherited by| hosp_obs


    model(("Model")) -->|Inherited by| basic

    simprw -->|Composes| rtrw
    rtrw -->|Composes| basic
    basic -->|Inherited by| hosp


    pois -->|Composes| hosp
    hosp_obs -->|Composes| hosp

    %% Metaclasses
    classDef Metaclass color:black,fill:white
    class rp,model Metaclass

    %% Random process
    classDef Randproc fill:purple,color:white
    class rtrw,simprw Randproc

    %% Models
    classDef Models fill:teal,color:white
    class basic,hosp Models
```
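The inheritance-plus-composition pattern in the diagram can be sketched in plain Python. In the minimal sketch below, `abc.ABC` stands in for the `RandomProcess` metaclass. The class names `SimpleRandomWalk` and `RtRandomWalk` and the `sample()` signature are hypothetical simplifications, not `pyrenew`'s API. The assumption that the Rt process exponentiates a random walk on the log scale is also ours.

```python
from abc import ABC, abstractmethod
import math
import random


class RandomProcess(ABC):
    """Hypothetical stand-in for pyrenew's RandomProcess: anything that can be sampled."""

    @abstractmethod
    def sample(self, n_timepoints: int, rng: random.Random) -> list[float]:
        ...


class SimpleRandomWalk(RandomProcess):
    """Gaussian random walk starting at `init` with step standard deviation `sd`."""

    def __init__(self, init: float = 0.0, sd: float = 0.1):
        self.init = init
        self.sd = sd

    def sample(self, n_timepoints, rng):
        values = [self.init]
        for _ in range(n_timepoints - 1):
            values.append(values[-1] + rng.gauss(0.0, self.sd))
        return values


class RtRandomWalk(RandomProcess):
    """Hypothetical Rt process that *composes* a walk on log(Rt) and exponentiates it."""

    def __init__(self, log_rt_walk: RandomProcess):
        self.log_rt_walk = log_rt_walk

    def sample(self, n_timepoints, rng):
        return [math.exp(x) for x in self.log_rt_walk.sample(n_timepoints, rng)]


rt_process = RtRandomWalk(SimpleRandomWalk(init=0.0, sd=0.05))
rt = rt_process.sample(30, random.Random(223))
print(len(rt), rt[0])  # 30 1.0
```

Because `RtRandomWalk` relies only on the `sample()` interface of the process it wraps, any other `RandomProcess` could be swapped in for the log-scale walk; this is the kind of reuse the diagram's "Composes" arrows describe.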

We start by loading the needed components to build a basic renewal model:

```python
import jax.numpy as jnp
import numpy as np
import numpyro as npro
from pyrenew.models import BasicRenewalModel
from pyrenew.observations import (
    InfectionsObservation,
    PoissonObservation,
)
```

In the basic renewal model, we can define two components: the infections component and the Rt process. The infections component (passed as `infections_obs`) should be an instance of `RandomProcess`; here we use the `InfectionsObservation` random process.

Here, `InfectionsObservation` is built on an underlying Poisson observation process. When instantiating `PoissonObservation` (which is also an instance of `RandomProcess`), we specify the names of two variables: the expected infections (`rate_varname`) and the infection counts (`counts_varname`). If observed counts are available, they are ultimately passed to `numpyro.sample(obs=...)`; otherwise, the counts are treated as a latent variable and sampled.
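The observed-versus-latent distinction can be mimicked without NumPyro. The sketch below is a hypothetical, pure-Python stand-in for a Poisson `numpyro.sample` site (`poisson_site` is not part of `pyrenew` or NumPyro): with `obs` supplied it returns that value together with its log-likelihood, and without it it draws a value, as a latent site would.

```python
import math
import random
from typing import Optional, Tuple


def poisson_site(rate: float, obs: Optional[int] = None,
                 rng: Optional[random.Random] = None) -> Tuple[int, float]:
    """Toy analogue of numpyro.sample(name, dist.Poisson(rate), obs=obs).

    If `obs` is given, the site is conditioned on it and only contributes its
    log-likelihood. If `obs` is None, a value is drawn, i.e. the site is latent.
    """
    if obs is None:
        rng = rng or random.Random()
        # Knuth's multiplication method: adequate for the small rates of a sketch.
        threshold, k, p = math.exp(-rate), 0, 1.0
        while True:
            p *= rng.random()
            if p <= threshold:
                break
            k += 1
        obs = k
    # log P(obs | rate) = obs * log(rate) - rate - log(obs!)
    log_prob = obs * math.log(rate) - rate - math.lgamma(obs + 1)
    return obs, log_prob


# Observed: the value is fixed and only scored.
observed_value, observed_lp = poisson_site(rate=2.0, obs=3)
# Latent: the value is simulated.
latent_value, latent_lp = poisson_site(rate=2.0, rng=random.Random(1))
```

In both branches the log-probability is computed the same way; the only difference is where the value comes from, which mirrors how supplying or omitting `obs` switches a site between conditioning and sampling.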

```python
# Infections observation process: a uniform generation interval over four time
# steps, with infections observed through a Poisson observation model
infections_obs = InfectionsObservation(
    gen_int=jnp.array([0.25, 0.25, 0.25, 0.25]),
    inf_observation_model=PoissonObservation(
        rate_varname='infections_mean',
        counts_varname='infections_obs',
    ),
)
```
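The `gen_int` argument is a generation-interval distribution: here, four equal weights that sum to one. A common way to turn such a distribution and an Rt sequence into latent infections is the renewal equation $I_t = R_t \sum_{s=1}^{S} I_{t-s}\, g_s$. The sketch below is a deterministic NumPy version of that textbook equation; the function name `renew_infections` and its seeding argument `i0` are hypothetical, and `pyrenew`'s own implementation (indexing, seeding, stochasticity) may differ.

```python
import numpy as np


def renew_infections(rt, gen_int, i0):
    """Deterministic renewal equation: I_t = R_t * sum_{s=1}^{S} I_{t-s} * g_s.

    rt      : reproduction numbers for the time points to simulate
    gen_int : generation-interval weights g_1, ..., g_S
    i0      : seed infections (oldest first); needs at least S values
    """
    gen_int = np.asarray(gen_int, dtype=float)
    infections = [float(x) for x in i0]
    n_seed = len(infections)
    for r in rt:
        # Reverse the most recent S infections so that I_{t-1} lines up with g_1.
        recent = np.array(infections[-len(gen_int):][::-1])
        infections.append(float(r * recent @ gen_int))
    # Drop the seed values and return only the simulated time points.
    return np.array(infections[n_seed:])


gen_int = np.array([0.25, 0.25, 0.25, 0.25])
flat = renew_infections(rt=[1.0] * 5, gen_int=gen_int, i0=[10.0] * 4)
growing = renew_infections(rt=[2.0] * 3, gen_int=gen_int, i0=[10.0] * 4)
```

Because `gen_int` sums to one, a constant Rt of 1 keeps infections flat at the seed level, while an Rt above 1 makes each new value a multiple of a weighted average of recent infections, so the series grows.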

With the observation process for the latent infections defined, we can build the basic renewal model and generate a sample by calling its `model()` method:

```python
model1 = BasicRenewalModel(infections_obs=infections_obs)

# Draw a reproducible seed for NumPyro and simulate 30 time points
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 60)):
    sim_data = model1.model(constants=dict(n_timepoints=30))

sim_data
```

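The `model()` method of `BasicRenewalModel` returns a list whose first element is the `Rt` sequence and whose second is the `infections` sequence, which is why they are accessed below as `sim_data[0]` and `sim_data[1]`.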

```python
import matplotlib.pyplot as plt

# Time index covering every element of the simulated sequences
time = np.arange(len(sim_data[0]))

fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(time, sim_data[0])
axs[0].set_ylabel('Rt')

# Infections plot
axs[1].plot(time, sim_data[1])
axs[1].set_ylabel('Infections')

fig.suptitle('Basic renewal model')
fig.supxlabel('Time')
plt.tight_layout()
plt.show()
```

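Next, let's estimate the model from the simulated data. The simulated infections are passed through `random_variables` as `infections_obs`, and `n_timepoints` is set to match the simulation: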

```python
import jax

# The simulated sequences have n_timepoints + 1 elements
model_data = {'n_timepoints': len(sim_data[1]) - 1}

model1.run(
    num_warmup=2000,
    num_samples=1000,
    random_variables=dict(infections_obs=sim_data[1]),
    constants=model_data,
    rng_key=jax.random.PRNGKey(54),
)
```
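The `num_warmup` and `num_samples` arguments follow the usual MCMC convention: warmup iterations are run and then discarded, and the iterations that follow are kept as posterior draws. The toy sketch below illustrates that convention with a random-walk Metropolis sampler for a single Poisson rate. It is a deliberately simple, swapped-in technique (no step-size adaptation, no renewal structure), not the sampler that `model1.run()` uses, and `log_posterior` and `metropolis` are hypothetical names.

```python
import numpy as np


def log_posterior(theta, y):
    """Unnormalized log posterior for y_t ~ Poisson(exp(theta)), theta ~ Normal(0, 2)."""
    rate = np.exp(theta)
    log_lik = np.sum(y * theta - rate)  # drops the constant -log(y_t!)
    log_prior = -0.5 * (theta / 2.0) ** 2  # drops the normalizing constant
    return log_lik + log_prior


def metropolis(y, num_warmup=2000, num_samples=1000, step=0.2, seed=54):
    """Random-walk Metropolis on theta = log(rate); returns post-warmup rate draws."""
    rng = np.random.default_rng(seed)
    y = np.asarray(y, dtype=float)
    theta = 0.0
    current_lp = log_posterior(theta, y)
    kept = []
    for i in range(num_warmup + num_samples):
        proposal = theta + step * rng.normal()
        proposal_lp = log_posterior(proposal, y)
        # Accept with probability min(1, exp(proposal_lp - current_lp)).
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            theta, current_lp = proposal, proposal_lp
        # Warmup iterations are discarded; only the last num_samples are kept.
        if i >= num_warmup:
            kept.append(theta)
    return np.exp(np.array(kept))


counts = np.array([4, 6, 5, 5, 7, 3, 5, 5])  # small illustrative dataset
rate_draws = metropolis(counts)
```

Sampling on the log scale keeps the rate positive without any constraint handling; gradient-based samplers such as NUTS typically rely on a similar unconstraining transform, although in NumPyro that is handled internally.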

Now, let's investigate the output by comparing 25 randomly selected posterior draws of `Rt` with the simulated `Rt`:

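```python
import polars as pl

# Posterior draws of Rt in long format: one row per (draw, time) pair
samps = model1.spread_draws([('Rt', 'time')])

fig, ax = plt.subplots(figsize=[4, 5])
ax.plot(sim_data[0], color="black", label="Simulated Rt")

# Overlay 25 randomly chosen posterior draws
samp_ids = np.random.randint(size=25, low=0, high=999)
for samp_id in samp_ids:
    sub_samps = samps.filter(pl.col("draw") == samp_id).sort(pl.col("time"))
    ax.plot(sub_samps.select("time").to_numpy(),
            sub_samps.select("Rt").to_numpy(), color="darkblue", alpha=0.1)

# Set the log scale first: set_yscale() resets the tick locators
ax.set_yscale("log")
ax.set_ylim([0.4, 1 / 0.4])
ax.set_yticks([0.5, 1, 2], labels=["0.5", "1", "2"])
ax.minorticks_off()
ax.set_xlabel("Time")
ax.set_ylabel("Rt")
ax.legend()
plt.show()
```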
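Plotting individual draws shows the spread of the posterior; a common complementary summary is a pointwise median with a central credible interval. The sketch below assumes the `Rt` draws have already been arranged into a NumPy array with one row per draw and one column per time point (for example, by reshaping `samps`); `summarize_draws` is a hypothetical helper, not a `pyrenew` function, and the test array is synthetic.

```python
import numpy as np


def summarize_draws(draws, prob=0.9):
    """Pointwise median and central `prob` interval for draws shaped (n_draws, n_time)."""
    draws = np.asarray(draws, dtype=float)
    alpha = (1.0 - prob) / 2.0
    # np.quantile returns one row per requested quantile: lower, median, upper.
    lower, median, upper = np.quantile(draws, [alpha, 0.5, 1.0 - alpha], axis=0)
    return lower, median, upper


# Synthetic check: each of 3 time points has draws 0, 1, ..., 100.
synthetic = np.tile(np.arange(101.0)[:, None], (1, 3))
lower, median, upper = summarize_draws(synthetic)
```

The three arrays can then be drawn with Matplotlib's `ax.fill_between(t, lower, upper, alpha=0.3)` and `ax.plot(t, median)`, where `t` is the time axis.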