# Pyrenew demo
This demo simulates some basic renewal process data and then fits to it using `pyrenew`.

You'll need to install `pyrenew` first. You'll also need working installations of `matplotlib`, `numpy`, `jax`, `numpyro`, and `polars`.

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use("seaborn-v0_8-whitegrid")

# Notebook-wide figure defaults, layered on top of the style above:
# 300 dpi output, 10pt text, mathtext tick labels, a grid with no axes
# spines, and an opaque rounded legend frame.
# Optional: add "text.usetex": True to render all text through LaTeX.
mpl.rcParams.update(
    {
        "figure.dpi": 300,
        "font.size": 10,
        "axes.formatter.use_mathtext": True,
        "axes.grid": True,
        "axes.spines.right": False,
        "axes.spines.left": False,
        "axes.spines.top": False,
        "axes.spines.bottom": False,
        "legend.fancybox": True,
        "legend.frameon": True,
        "legend.framealpha": 1,
    }
)

import jax
import jax.numpy as jnp
import numpy as np
from numpyro.handlers import seed
import numpyro.distributions as dist

In [None]:
from pyrenew.processes import SimpleRandomWalkProcess

# Draw one realization of a simple random walk and plot its exponential.
# NOTE(review): presumably dist.Normal(0, 0.001) is the step distribution and
# `duration` the number of time points — confirm against pyrenew's
# SimpleRandomWalkProcess. The numpyro seed is chosen at random, so the
# picture changes on every run.
walk_seed = np.random.randint(0, 1000)
q = SimpleRandomWalkProcess(dist.Normal(0, 0.001))
with seed(rng_seed=walk_seed):
    q_samp = q.sample(duration=100)

plt.plot(np.exp(q_samp))

In [None]:
from pyrenew.observations import (
    InfectionsObservation,
    HospitalizationsObservation,
    PoissonObservation
)

from pyrenew.models import HospitalizationsModel
from pyrenew.processes import RtRandomWalkProcess

# Build the model components.
# NOTE(review): semantics inferred from argument names only — presumably the
# InfectionsObservation array is a generation-interval PMF (uniform over 4
# entries) and `inf_hosp_int` an infection-to-hospitalization delay PMF;
# both arrays sum to 1. Confirm against the pyrenew documentation.
infections_obs = InfectionsObservation(jnp.array([0.25] * 4))
Rt_process = RtRandomWalkProcess()

inf_hosp_int = jnp.array(
    [0, 0, 0,0,0,0,0,0,0,0,0,0,0, 0.25, 0.5, 0.1, 0.1, 0.05]
)
hosp_obs = HospitalizationsObservation(
    inf_hosp_int=inf_hosp_int,
    hosp_dist=dist.Poisson,
)

# Assemble the hospitalizations model from its three components.
hospmodel = HospitalizationsModel(
    Rt_process=Rt_process,
    infections_obs=infections_obs,
    hosp_obs=hosp_obs,
)

In [None]:
# Draw one forward simulation of 30 time points from the model under a
# randomly chosen numpyro seed (varies between runs), then display it.
sim_seed = np.random.randint(1, 60)
with seed(rng_seed=sim_seed):
    x = hospmodel.model(constants={"n_timepoints": 30})
x

In [None]:
# Three stacked panels sharing the time axis:
#   top    - x[0], the simulated Rt trajectory (log scale, limits 1/5..5);
#   middle - x[1] (log scale);
#   bottom - x[3], drawn as points.
# NOTE(review): the meanings of x[1] and x[3] depend on the field order of
# the model's sample tuple — presumably infections and hospitalizations;
# confirm against pyrenew.
fig, ax = plt.subplots(nrows=3, sharex=True)
# Plot only the Rt component: `x` itself is the whole multi-field sample
# (it also carries x[1], x[3], x.samp_hosp), not a single plottable series.
ax[0].plot(x[0])
ax[0].set_ylim([1 / 5, 5])
ax[1].plot(x[1])
ax[2].plot(x[3], "o")
for axis in ax[:-1]:
    axis.set_yscale("log")

In [None]:
# Fit the model to the simulated hospitalization counts: 1000 warmup and
# 1000 posterior samples, with a fixed PRNG key so the fit is repeatable.
observed_hosp = x.samp_hosp
sim_dat = {"observed_hospitalizations": observed_hosp}
# NOTE(review): the fit uses one fewer time point than len(samp_hosp);
# presumably the model emits n_timepoints + 1 values — confirm.
constants = {"n_timepoints": len(observed_hosp) - 1}

hospmodel.run(
    num_warmup=1000,
    num_samples=1000,
    random_variables=sim_dat,
    constants=constants,
    rng_key=jax.random.PRNGKey(54),
)

In [None]:
# Display the simulated model output `x` again.
x

In [None]:
# Print the posterior summary of the model fitted above. The fitted model is
# `hospmodel`; no variable `a` exists in this notebook.
hospmodel.print_summary()

In [None]:
from pyrenew.mcmcutils import spread_draws
# Reshape the posterior samples of "Rt" from the fitted `hospmodel` into a
# long polars table; the plot below uses its "draw", "time" and "Rt" columns.
samps = spread_draws(hospmodel.mcmc.get_samples(), [("Rt", "time")])

In [None]:
import numpy as np
import polars as pl
fig, ax = plt.subplots(figsize=[4, 5])

# Overlay up to 25 distinct posterior Rt trajectories. Choose among the draw
# ids actually present in `samps` instead of a hard-coded range: the old
# `randint(high=999)` could never pick the last draw (`high` is exclusive),
# silently depended on num_samples=1000, and could repeat a draw.
draw_ids = samps["draw"].unique().to_numpy()
samp_ids = np.random.choice(
    draw_ids, size=min(25, len(draw_ids)), replace=False
)
for samp_id in samp_ids:
    sub_samps = samps.filter(pl.col("draw") == samp_id).sort("time")
    ax.plot(
        sub_samps["time"].to_numpy(),
        sub_samps["Rt"].to_numpy(),
        color="darkblue",
        alpha=0.1,
    )

# Simulated Rt on top of the draws, in a contrasting color so it stays visible.
ax.plot(x[0], color="black")

# Switch to log scale BEFORE setting ticks: changing the scale installs the
# scale's default locators/formatters, which discarded the ticks previously.
ax.set_yscale("log")
ax.set_ylim([0.4, 1 / 0.4])
ax.set_yticks([0.5, 1, 2], labels=["0.5", "1", "2"])
# The range spans < 1 decade, so log minor ticks would otherwise be labelled
# and clutter the explicit ticks above.
ax.minorticks_off()
