From 3b77175d7c2fe0e90bd7aa14edb3228e24439241 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 5 May 2026 16:44:02 -0400 Subject: [PATCH 1/3] rewrite on fitting and running the model --- .../tutorials/building_multisignal_models.qmd | 244 ++++++++---------- 1 file changed, 109 insertions(+), 135 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index d0b4d526..4291f83a 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -3,6 +3,12 @@ title: Building Multi-Signal Renewal Models format: gfm: code-fold: true + html: + toc: true + embed-resources: true + self-contained-math: true + code-fold: true + code-tools: true engine: jupyter jupyter: jupytext: @@ -617,112 +623,20 @@ The priors on `I0_rv` and on the ascertainment rates resolve this ambiguity. ## Fitting the Model to Data: `model.run()` -When you call `model.run()`, you supply two types of information: +When you call `model.run()`, you supply three types of information: +- **Model and population information** --- the fitting period, total population size, and subpopulation fractions - **Observation data** --- one data dictionary per registered observation process -- **Population structure** --- how the jurisdiction is divided into subpopulations - -### Shared Time Axis - -All observation data uses a **shared time axis** `[0, n_total)` where `n_total = n_init + n_days`. 
-This shared axis aligns observations with the internal infection vectors: - -- Index 0 corresponds to the first day of the initialization period -- Index `n_init` corresponds to the first day of actual observations -- Index `n_total - 1` corresponds to the last observation day - -The model provides helper methods to align your data with this shared axis: - -- `model.pad_observations(obs)` --- prepends `n_init` NaN values to dense observation vectors -- `model.shift_times(times)` --- adds `n_init` to sparse time indices - -### Observation Data by Signal Type +- **MCMC controls** --- basic settings for posterior sampling -Each observation process's `name` attribute becomes the keyword argument for passing data to `model.run()`: - -```python -builder.add_observation(hosp_obs) # hosp_obs.name="hospital" → hospital={...} -builder.add_observation(ww_obs) # ww_obs.name="wastewater" → wastewater={...} -``` - -#### Jurisdiction-level signals (dense) - -The jurisdiction-level hospital admissions data is specified as a `PopulationCounts` observations process with dense data padded to length `n_total`: - -```python -hospital = { - "obs": model.pad_observations(hosp_counts), # shape: (n_total,), NaN-padded -} -``` - -The `pad_observations` method prepends `n_init` NaN values. -NaN marks the initialization period where predictions exist but observations do not. -You can also use NaN to mark missing data within the observation period. - -#### Subpopulation-level signals (sparse) - -The subpopulation-level wastewater data is specified as a `Wastewater` observations process with sparse indexing on the shared time axis: - -```python -wastewater = { - "obs": jnp.array([...]), # observed log concentrations (n_obs,) - "times": model.shift_times(ww_times), # time indices on shared axis - "subpop_indices": jnp.array( - [...] - ), # which subpopulation (selects infection column) - "sensor_indices": jnp.array( - [...] 
- ), # which WWTP/lab pair (selects noise parameters) - "n_sensors": int, # total number of WWTP/lab pairs -} -``` - -The `shift_times` method adds `n_init` to convert from natural coordinates (0 = first observation day) to the shared time axis. - -**Understanding `subpop_indices`**: The latent process generates infections for all subpopulations as a matrix of shape `(T, n_subpops)`. -Each observation selects which column (subpopulation) it came from using `subpop_indices`. -This is how observation processes "know" which subpopulations they observe---the user specifies this mapping at sample/run time. - -A **subpopulation** is a portion of the jurisdiction's population (e.g., a catchment area). -A **sensor** is a measurement source --- typically a WWTP/lab pair --- that produces observations. -Multiple sensors can observe the same subpopulation (e.g., different labs processing samples from the same catchment), so `subpop_indices` and `sensor_indices` may differ. - -- `subpop_indices` links each observation to the appropriate infection column (0-indexed into the subpopulations) -- `sensor_indices` selects which sensor's noise parameters (mode and sd) to apply - -**Example**: A jurisdiction has 6 subpopulations (indices 0-5), where 5 have wastewater monitoring and 1 does not. -The `subpop_fractions` array has 6 elements. -If subpopulation 2 lacks wastewater monitoring, wastewater observations would have `subpop_indices` values only in {0, 1, 3, 4, 5}---never 2. -The monitored subpopulations need not be contiguous. -The latent process still generates infections for all 6 subpopulations; the wastewater observation just doesn't see subpopulation 2. - -### Population Structure - -Population structure is specified via a single array of fractions for all subpopulations: - -```python -model.run( - ..., - subpop_fractions=jnp.array([...]), # one fraction per subpopulation, must sum to 1 -) -``` - -This specifies 6 subpopulations with their population fractions. 
-The fractions must sum to 1.0. -The latent process generates infections for all 6 subpopulations. - -Which subpopulations each observation process "sees" is determined by the `subpop_indices` in the observation data, not by the population structure. -For example, if wastewater monitoring covers only 5 of the 6 subpopulations (say, all except subpopulation 2), the wastewater observation data would have `subpop_indices` values in {0, 1, 3, 4, 5} but never 2. -The monitored subpopulations can be any subset of {0, ..., n_subpops-1}. - -### Example `model.run()` Call +For this model, a complete call has the following structure: ```python model.run( num_warmup=500, num_samples=500, + rng_key=make_rng_key(), mcmc_args={"num_chains": 4, "progress_bar": False}, - # Model arguments (passed through to sample()) n_days_post_init=n_days, population_size=population_size, subpop_fractions=subpop_fractions, @@ -731,30 +645,35 @@ model.run( samples = model.mcmc.get_samples() ``` -where `**obs_data` is a data dictionary which supplies the observation start date as `obs_start_date` and a data dictionary for each set of signal data, where the name of the data dictionary corresponds to the signal name registered on the builder. -For this example, such a data dictionary would have the following structure: +The `**obs_data` dictionary supplies the observation start date and one data dictionary per observation process. +The keys of the observation dictionaries are the names registered on the builder: `hospital` and `wastewater`. -```python -{ - "obs_start_date": ..., - "hospital": { - "obs": ... - }, - "wastewater": { - "obs": ... - "times": ... - "subpop_indices": ... - "sensor_indices": ... - "n_sensors": ... - }, -} -``` +### Model Time vs. Observation Time -## Running the Model +In the raw data, day 0 is the first day with observations. 
+PyRenew adds an initialization period before this date so the renewal process has enough past infections to generate infections during the observed period. + +This means the model has two related time scales: + +- **Observation time** starts at the first day in the data. +- **Model time** starts `n_init` days earlier. + +On the model time axis, indices `0` through `n_init - 1` are initialization days. +The first observed day is model index `n_init`, and the last observed day is model index `n_init + n_days - 1`. + +The model provides helper methods to convert observation data onto this model time axis: + +- `model.pad_observations(obs)` prepends `n_init` NaN values to dense daily observation vectors. +- `model.shift_times(times)` adds `n_init` to sparse observation time indices. + +The hospital admissions signal is dense: it has one value per observation day, so we pad it. +The wastewater signal is sparse: each row has an observation day, so we shift its time index. + +### Population Structure First we declare the population structure. We have 6 subpopulations, where 5 have wastewater monitoring and 1 does not. -The subpopulations with wastewater monitoring need not be contiguous indices---they could be any subset of {0, 1, ..., n_subpops-1}. +The subpopulations with wastewater monitoring need not be contiguous indices; they can be any subset of {0, 1, ..., n_subpops-1}. ```{python} #| label: population-structure @@ -779,9 +698,14 @@ print( print(f"Total population: {float(jnp.sum(subpop_fractions)):.0%}") ``` +The `subpop_fractions` array defines all subpopulations in the latent infection process, and its entries must sum to 1.0. +Which subpopulations each observation process sees is determined by the observation data. +For wastewater, `subpop_indices` links each observation to the appropriate subpopulation, while `sensor_indices` selects the sensor-specific noise parameters. 
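The role of the two index arrays can be sketched in plain NumPy. This is a standalone toy (the array shapes, values, and variable names here are illustrative stand-ins, not the model's internals): `subpop_indices` picks the infection column each sample came from, while `sensor_indices` picks the noise parameters of the WWTP/lab pair that produced it.

```python
import numpy as np

# Toy latent infections: 5 model days x 6 subpopulations.
infections = np.arange(5 * 6, dtype=float).reshape(5, 6)
# One hypothetical bias ("mode") per sensor; 3 sensors in this sketch.
sensor_mode = np.array([0.1, -0.2, 0.0])

times = np.array([1, 1, 4])          # model-time index of each observation
subpop_indices = np.array([0, 3, 3])  # subpopulation each sample came from
sensor_indices = np.array([0, 1, 2])  # WWTP/lab pair that produced it

# Each observation reads one (time, subpopulation) infection value...
selected = infections[times, subpop_indices]
# ...and applies its own sensor's noise parameters.
biased = selected + sensor_mode[sensor_indices]
```

Two sensors may share a subpopulation (as the second and third rows do here), which is why the two index arrays are kept separate.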
+
+### Preparing Observation Data
+
 We define a function to prepare observation data using the model's helper methods to align with the shared time axis.
-The returned dictionary is structured to match the keyword arguments of `model.run()`: `obs_start_date` at the top level, plus one sub-dictionary per observation process, keyed by the name registered on the builder (`hospital`, `wastewater`).
-At call time the returned dict is unpacked with `**` (for example, `**obs_data_90` in the 90-day fit below), forwarding each entry as a keyword argument to `model.run()`.
+The returned dictionary is structured to match the keyword arguments of `model.run()`: `obs_start_date` at the top level, plus one sub-dictionary per observation process.

```{python}
#| label: prepare-observation-data
@@ -825,10 +749,10 @@ def prepare_observation_data(
    ww_conc = ww_data["observed_conc"][ww_mask]
    ww_sensors = ww_data["site_ids"][ww_mask]

-    # Map wastewater sensors to subpopulation indices
-    # Each sensor is assigned to one of the monitored subpopulations.
-    # In practice, this mapping comes from your data (which WWTP serves which catchment).
-    # For this demo, we cycle sensors through the monitored subpopulations.
+    # Map wastewater sensors to subpopulation indices.
+    # In practice, this mapping comes from your data
+    # (which WWTP serves which catchment). For this demo, we cycle
+    # sensors through the monitored subpopulations.
    n_ww_sensors = ww_data["n_sites"]
    n_monitored = len(ww_monitored_subpops)
    sensor_to_subpop = {
        i: int(ww_monitored_subpops[i % n_monitored]) for i in range(n_ww_sensors)
    }
    ww_subpop_indices = jnp.array([sensor_to_subpop[int(s)] for s in ww_sensors])

    return {
        "obs_start_date": obs_start_date,
        "hospital": {
            "obs": hosp_obs,
        },
        "wastewater": {
            "obs": ww_conc,
            "times": ww_times,
            "subpop_indices": ww_subpop_indices,
            "sensor_indices": ww_sensors,
            "n_sensors": n_ww_sensors,
        },
    }
```

-### Fit: 90 Days
+The returned dictionary has the following structure:

-Putting this altogether, we align the data with the model time and call `model.run()`.
-We run 4 sampler chains.
+```python
+{
+    "obs_start_date": ...,
+    "hospital": {
+        "obs": ...,
+    },
+    "wastewater": {
+        "obs": ...,
+        "times": ...,
+        "subpop_indices": ...,
+        "sensor_indices": ...,
+        "n_sensors": ...,
+ }, +} +``` + +At call time, this dictionary is unpacked with `**`, forwarding each entry as a keyword argument to `model.run()`. + +### MCMC Controls + +`model.run()` uses NumPyro's No-U-Turn Sampler (NUTS) through `numpyro.infer.MCMC`. +For this introductory example, the main controls are: + +- `num_warmup`, the number of warmup iterations +- `num_samples`, the number of posterior samples to keep +- `rng_key`, the random seed for sampling +- `mcmc_args`, optional NumPyro MCMC settings + +In the fits below, we use 4 chains and turn off the progress bar: + +```python +mcmc_args = {"num_chains": 4, "progress_bar": False} +``` + +See the [NumPyro MCMC reference](https://num.pyro.ai/en/stable/mcmc.html) for more advanced MCMC options. + +## Running the Model + +We now prepare the observation dictionaries described above and unpack them into `model.run()`. +The first fit uses 90 days of data. + +### Fit: 90 Days ```{python} #| label: fit-90-days @@ -895,7 +859,7 @@ print(f"Elapsed time: {elapsed_90:.1f} seconds") ``` We use [ArviZ](https://python.arviz.org/en/stable/) to assess MCMC convergence and mixing via the $\hat{R}$ statistic and effective sample size (ESS). -Before running these diagnostics, it is necessary to we trim the first `n_init` time steps from all time-series variables. +Before running these diagnostics, we trim the first `n_init` time steps from all time-series variables. Since the model cannot estimate latent infections until it has seen a full generation interval's worth of data, these early time steps have no meaningful epidemiological interpretation and therefore should be excluded from summaries and visualizations. 
```{python} @@ -1224,18 +1188,28 @@ The high uncertainty in days 60-90 of the 90-day fit is exactly what we'd expect ## Summary -This tutorial demonstrated composing a multi-signal renewal model using `PyrenewBuilder`: +In this tutorial, we built a renewal model that combines hospital admissions and wastewater concentrations through a shared latent infection process. +The latent process describes infections through time and across subpopulations; each observation process describes how one data stream is generated from those infections. + +The main workflow was: + +1. Configure the latent infection process with `configure_latent`. +2. Add named observation processes with `add_observation`. +3. Build the model with `build`. +4. Prepare observation dictionaries whose keys match the observation process names. +5. Fit the model with `model.run`. -1. **Configure latent process** (`configure_latent`): generation interval, initial infections, temporal dynamics -2. **Add observation processes** (`add_observation`): each declares its infection resolution and gets a name for data binding -3. **Build and run** (`build`, `model.run`): the model routes infections to observations based on resolution and runs NUTS inference +A few details are especially important when adapting this pattern: -### Key Concepts +- Observation process names become the data argument names passed to `model.run`, such as `hospital={...}` and `wastewater={...}`. +- Dense daily observations are padded with `model.pad_observations`. +- Sparse observations use shifted time indices from `model.shift_times`. +- `subpop_fractions` defines the latent subpopulations; `subpop_indices` tells an observation process which subpopulation each measurement observes. +- The model time axis starts before the first observed day because PyRenew adds an initialization period for the renewal process. 
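The name-based data binding in the first bullet above is plain Python keyword unpacking: `**` turns each top-level key of the observation dictionary into a keyword argument, so those keys must match the names registered on the builder. A minimal sketch (the `run` function here is a hypothetical stand-in, not the model's method):

```python
def run(*, obs_start_date, **signal_data):
    # Stand-in for model.run(): each registered observation process
    # receives the sub-dictionary whose key matches its name.
    return sorted(signal_data)

obs_data = {
    "obs_start_date": "2024-01-01",
    "hospital": {"obs": [3, 5, 4]},
    "wastewater": {"obs": [1.2], "times": [0], "n_sensors": 1},
}

# The non-date keys become keyword arguments, one per observation process.
signal_names = run(**obs_data)
```

A misspelled key would surface as an unexpected-keyword error at call time, which is why the dictionary keys and builder names must agree exactly.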
-- **Two-part structure**: Renewal models separate latent infection dynamics from observation processes -- **Infection resolution**: Observation processes declare whether they need aggregate or subpop-level infections -- **Data routing**: `PyrenewBuilder` automatically routes infection trajectories to the appropriate observation processes -- **Time alignment**: Observations must be offset by `n_initialization_points` to align with model time +The 90-day and 180-day fits also show why the end of the observation window matters. +Future observations help constrain past latent infections, so posterior uncertainty is usually largest near the most recent observed days. +That same edge uncertainty is what carries forward when the model is used for forecasting. ### Next Steps From 964a544a5e53147090079df94bd93178305a1bdb Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Mon, 11 May 2026 12:57:49 -0400 Subject: [PATCH 2/3] changes per tutorial review --- .../tutorials/building_multisignal_models.qmd | 393 +++++------------- 1 file changed, 114 insertions(+), 279 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 8a5254fb..a51367e7 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -3,12 +3,6 @@ title: Building Multi-Signal Renewal Models format: gfm: code-fold: true - html: - toc: true - embed-resources: true - self-contained-math: true - code-fold: true - code-tools: true engine: jupyter jupyter: jupytext: @@ -75,7 +69,7 @@ from pyrenew.randomvariable import DistributionalVariable from pyrenew.latent import ( SubpopulationInfections, - AR1, + DifferencedAR1, RandomWalk, StepwiseTemporalProcess, WeeklyTemporalProcess, @@ -94,9 +88,6 @@ from pyrenew.observation import ( from pyrenew.time import MMWR_WEEK ``` -Temporal process parameters are `RandomVariable`s. 
-This tutorial uses `DeterministicVariable` wrappers to keep those hyperparameters fixed. - ## Overview Renewal models in PyRenew combine two types of components: @@ -105,7 +96,7 @@ Renewal models in PyRenew combine two types of components: 2. **Observation processes**: Transform latent infections into observable signals (hospital admissions, wastewater concentrations, etc.) by applying delays, ascertainment, and noise -A **multi-signal model** combines multiple observation processes---each representing a different data stream, e.g., hospital admissions, emergency deparatment visits, wastewater concentrations, which stem from the same underlying latent infection process. +A **multi-signal model** combines multiple observation processes---each representing a different data stream, e.g., hospital admissions, emergency department visits, wastewater concentrations, which stem from the same underlying latent infection process. By jointly modeling these signals, we can improve estimation and prediction of the time-varying reproduction number $\mathcal{R}(t)$. Such a model must: @@ -143,7 +134,7 @@ This tutorial demonstrates building a multi-signal renewal model using: In this tutorial, we build a model that jointly fits two data streams to a shared latent infection process: -- **Hospital admissions** --- jurisdiction-level counts that reflect *total* infections across all subpopulations, delayed and underascertained +- **Hospital admissions** --- jurisdiction-level counts that reflect a delayed and partially observed subset of total infections across all subpopulations - **Wastewater concentrations** --- site-level measurements from a subset of subpopulations (catchment areas), reflecting viral shedding and dilution The diagram below shows how data flows through the model. 
@@ -248,7 +239,7 @@ log_rt_time_0_rv = DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5) We configure two temporal processes: -- **Jurisdiction-level** (`baseline_rt_process`): AR(1) process for the baseline $\mathcal{R}(t)$ +- **Jurisdiction-level** (`baseline_rt_process`): DifferencedAR(1) process for the baseline $\mathcal{R}(t)$ - **Subpopulation-level** (`subpop_rt_deviation_process`): RandomWalk for subpopulation deviations The RandomWalk allows flexible evolution of subpopulation-specific transmission without mean reversion. @@ -256,10 +247,10 @@ The RandomWalk allows flexible evolution of subpopulation-specific transmission ```{python} #| label: temporal-processes -# AR1 provides mean-reverting behavior for baseline Rt -baseline_rt_process = AR1( - autoreg_rv=DeterministicVariable("autoreg", 0.9), - innovation_sd_rv=DeterministicVariable("innovation_sd", 0.05), +# DifferencedAR1 allows persistent trends while stabilizing the growth rate. +baseline_rt_process = DifferencedAR1( + autoreg_rv=DeterministicVariable("autoreg", 0.5), + innovation_sd_rv=DeterministicVariable("innovation_sd", 0.01), ) # RandomWalk allows flexible subpopulation deviations @@ -277,12 +268,12 @@ This separates three model choices: - **Model time axis**: the daily axis used by the renewal equation and delay convolutions - **Observation cadence**: the temporal granularity for each signal, such as daily ED visits or weekly hospital admissions -The AR(1) process above samples one value per model day: +The DifferencedAR(1) process above samples one value per model day: ```python -baseline_rt_process = AR1( - autoreg_rv=DeterministicVariable("autoreg", 0.9), - innovation_sd_rv=DeterministicVariable("innovation_sd", 0.05), +baseline_rt_process = DifferencedAR1( + autoreg_rv=DeterministicVariable("autoreg", 0.5), + innovation_sd_rv=DeterministicVariable("innovation_sd", 0.01), ) ``` @@ -291,9 +282,9 @@ The wrapper samples a weekly trajectory and broadcasts it to the daily model axi 
```python
weekly_baseline_rt_process = WeeklyTemporalProcess(
    inner=DifferencedAR1(
        autoreg_rv=DeterministicVariable("autoreg", 0.5),
        innovation_sd_rv=DeterministicVariable("innovation_sd", 0.01),
    ),
    start_dow=MMWR_WEEK,
)
```

@@ -714,82 +705,68 @@ print(f"Total population: {float(jnp.sum(subpop_fractions)):.0%}")

The `subpop_fractions` array defines all subpopulations in the latent infection process, and its entries must sum to 1.0.
Which subpopulations each observation process sees is determined by the observation data.
-For wastewater, `subpop_indices` links each observation to the appropriate subpopulation, while `sensor_indices` selects the sensor-specific noise parameters.
+
+Measurement data typically exhibits sensor-level variability: different instruments, labs, or sampling locations may have systematic biases and different precision levels.
+The wastewater observations record both the wastewater treatment plant from which the sample was collected and the laboratory at which the sample was processed.
+For wastewater, a "sensor" is a WWTP/lab pair --- the combination of treatment plant and laboratory associated with sensor-specific bias and measurement variability.
+Because the wastewater component's `HierarchicalNormalNoise` samples one mode and one standard deviation parameter per sensor, the data must identify which sensor produced each observation.
+The wastewater signal's data dictionary therefore includes:
+
+- `subpop_indices`, which links each observation to the appropriate subpopulation
+- `sensor_indices`, which selects the sensor-specific noise parameters
+- `n_sensors`, the total number of sensors, used to size the sensor-level noise parameters

### Preparing Observation Data

-We define a function to prepare observation data using the model's helper methods to align with the shared time axis.
-The returned dictionary is structured to match the keyword arguments of `model.run()`: `obs_start_date` at the top level, plus one sub-dictionary per observation process. +The observation data dictionary is structured to match the keyword arguments of `model.run()`: `obs_start_date` at the top level, plus one sub-dictionary per observation process. +The entries in the sub-dictionaries are forwarded as arguments to each observation process's `sample` method. + +The two observation streams are aligned differently because they are represented differently. +Hospital admissions are a dense daily series: every day in the fitting window has a position in the vector, so `model.pad_observations()` prepends `n_init` missing values for the initialization period. +Wastewater measurements are sparse rows: only sampled days appear in the table, so the observations themselves are not padded. +Instead, `model.shift_times()` adds `n_init` to each wastewater time index so those rows point to the correct days on the model's shared time axis. ```{python} #| label: prepare-observation-data -def prepare_observation_data( - model, - n_days_fit, - obs_start_date, - hosp_admits, - ww_data, - ww_monitored_subpops, -): - """ - Prepare observation data for fitting. - - Uses model.pad_observations() and model.shift_times() to align data - with the shared time axis. - - Parameters - ---------- - model : MultiSignalModel - The model (provides padding/shifting helpers) - n_days_fit : int - Number of days to include in fit - obs_start_date : date - Date of first observation - hosp_admits : array - Hospital admissions time series - ww_data : dict - Wastewater data dictionary - ww_monitored_subpops : array - Indices of subpopulations that have wastewater monitoring. - These are the valid values for subpop_indices in wastewater data. 
- """ - # Hospital: dense, NaN-padded to length n_total - hosp_obs = model.pad_observations(hosp_admits[:n_days_fit]) - - # Wastewater: sparse, times shifted to shared axis - ww_mask = ww_data["time_indices"] < n_days_fit - ww_times = model.shift_times(ww_data["time_indices"][ww_mask]) - ww_conc = ww_data["observed_conc"][ww_mask] - ww_sensors = ww_data["site_ids"][ww_mask] - - # Map wastewater sensors to subpopulation indices. - # In practice, this mapping comes from your data - # (which WWTP serves which catchment). For this demo, we cycle - # sensors through the monitored subpopulations. - n_ww_sensors = ww_data["n_sites"] - n_monitored = len(ww_monitored_subpops) - sensor_to_subpop = { - i: int(ww_monitored_subpops[i % n_monitored]) for i in range(n_ww_sensors) - } - ww_subpop_indices = jnp.array([sensor_to_subpop[int(s)] for s in ww_sensors]) - - return { - "obs_start_date": obs_start_date, - "hospital": { - "obs": hosp_obs, - }, - "wastewater": { - "obs": ww_conc, - "times": ww_times, - "subpop_indices": ww_subpop_indices, - "sensor_indices": ww_sensors, - "n_sensors": n_ww_sensors, - }, - } +n_days_90 = 90 + +# Hospital: dense, NaN-padded to length n_total +hosp_obs = model.pad_observations(hosp_admits[:n_days_90]) + +# Wastewater: sparse, times shifted to shared axis +ww_mask = ca_ww_data["time_indices"] < n_days_90 +ww_times = model.shift_times(ca_ww_data["time_indices"][ww_mask]) +ww_conc = ca_ww_data["observed_conc"][ww_mask] +ww_sites = ca_ww_data["site_ids"][ww_mask] + +# Map wastewater sensors to subpopulation indices. +# In practice, this mapping comes from your data +# (which WWTP serves which catchment). For this demo, we cycle +# sensors through the monitored subpopulations. 
+n_ww_sites = ca_ww_data["n_sites"] +n_monitored = len(ww_monitored_subpops) +sensor_to_subpop = { + i: int(ww_monitored_subpops[i % n_monitored]) for i in range(n_ww_sites) +} +ww_subpop_indices = jnp.array([sensor_to_subpop[int(s)] for s in ww_sites]) + +obs_data_90 = { + "obs_start_date": obs_start_date, + "hospital": { + "obs": hosp_obs, + }, + "wastewater": { + "obs": ww_conc, + "times": ww_times, + "subpop_indices": ww_subpop_indices, + "sensor_indices": ww_sites, + "n_sensors": n_ww_sites, + }, +} ``` -The returned dictionary has the following structure: +The observation dictionary has the following structure: ```python { @@ -819,7 +796,7 @@ For this introductory example, the main controls are: - `rng_key`, the random seed for sampling - `mcmc_args`, optional NumPyro MCMC settings -In the fits below, we use 4 chains and turn off the progress bar: +For this fit, we use 4 chains and turn off the progress bar: ```python mcmc_args = {"num_chains": 4, "progress_bar": False} @@ -829,8 +806,8 @@ See the [NumPyro MCMC reference](https://num.pyro.ai/en/stable/mcmc.html) for mo ## Running the Model -We now prepare the observation dictionaries described above and unpack them into `model.run()`. -The first fit uses 90 days of data. +We now unpack the observation dictionary prepared above into `model.run()`. +The fit uses 90 days of data. ### Fit: 90 Days @@ -840,16 +817,6 @@ The first fit uses 90 days of data. 
# Clear JAX caches to avoid interference from earlier cells jax.clear_caches() -n_days_90 = 90 -obs_data_90 = prepare_observation_data( - model, - n_days_90, - obs_start_date, - hosp_admits, - ca_ww_data, - ww_monitored_subpops, -) - print(f"Fitting model with {n_days_90} days of data...") print(" This may take a few minutes...") @@ -1003,202 +970,70 @@ plot_df_90["signal"] = pd.Categorical( ) ``` -### Fit: 180 Days - -```{python} -#| label: fit-180-days - -# Clear JAX caches to avoid interference -jax.clear_caches() +### Month-by-month uncertainty -n_days_180 = 180 -obs_data_180 = prepare_observation_data( - model, - n_days_180, - obs_start_date, - hosp_admits, - ca_ww_data, - ww_monitored_subpops, -) - -print(f"Fitting model with {n_days_180} days of data...") -print(" This may take a few minutes...") - -start_time = time.time() -model.run( - num_warmup=1000, - num_samples=500, - rng_key=make_rng_key(), - mcmc_args={"num_chains": 4, "progress_bar": False}, - n_days_post_init=n_days_180, - population_size=population_size, - subpop_fractions=subpop_fractions, - **obs_data_180, -) -# Block until sampling completes for accurate timing -samples_180 = model.mcmc.get_samples() -jax.block_until_ready(samples_180) -elapsed_180 = time.time() - start_time -print(f"Elapsed time: {elapsed_180:.1f} seconds") -``` - -We check the model fit, as before. +The posterior ribbon above shows the infection trajectory and its 90% credible interval, but the ribbon width itself is hard to compare across days. +To focus directly on uncertainty, we compute the daily 90% interval width and summarize it over each 30-day period. 
```{python} -#| label: arviz-diagnostics-180 - -# ArviZ diagnostics for 180-day fit - -idata_180 = az.from_numpyro( - model.mcmc, - dims={ - "latent_infections": ["time"], - "SubpopulationInfections::infections_aggregate": ["time"], - "SubpopulationInfections::log_rt_baseline": ["time", "dummy"], - "SubpopulationInfections::rt_baseline": ["time", "dummy"], - "SubpopulationInfections::rt_subpop": ["time", "subpop"], - "SubpopulationInfections::subpop_deviations": ["time", "subpop"], - "latent_infections_by_subpop": ["time", "subpop"], - "hospital_predicted": ["time"], - "wastewater_predicted": ["time", "subpop"], - }, -) +#| label: monthly-CI -idata_180_trimmed = idata_180.map_over_datasets(trim_time) -az.summary(idata_180_trimmed, var_names=["latent_infections", "hospital_predicted"]) -``` - -```{python} -#| label: extract-quantiles-180 - -latent_inf = idata_180_trimmed.posterior["latent_infections"] - -quantiles_180 = { - "q05": latent_inf.quantile(0.05, dim=["chain", "draw"]).values, - "q50": latent_inf.quantile(0.50, dim=["chain", "draw"]).values, - "q95": latent_inf.quantile(0.95, dim=["chain", "draw"]).values, -} - -ci_width_180 = quantiles_180["q95"] - quantiles_180["q05"] -print(f"Posterior summary for {n_days_180} days:") -print(f" Mean 90% CI width: {ci_width_180.mean():,.0f} infections") -print(f" Median infections (day 90): {quantiles_180['q50'][90]:,.0f}") -``` - -```{python} -#| label: fig-posterior-180 -#| fig-cap: Posterior latent infections and observed hospitalizations (180 -#| days). 
+ci_width_90 = quantiles_90["q95"] - quantiles_90["q05"] -# Visualize posterior latent infections and observed hospitalizations (180 days) -infections_df_180 = pd.DataFrame( +ci_width_df = pd.DataFrame( { - "day": np.arange(n_days_180), - "median": quantiles_180["q50"], - "q05": quantiles_180["q05"], - "q95": quantiles_180["q95"], - "signal": "Latent Infections", + "day": np.arange(n_days_90), + "ci_width": ci_width_90, + "period": pd.cut( + np.arange(n_days_90), + bins=[-1, 29, 59, 89], + labels=["Days 0-29", "Days 30-59", "Days 60-89"], + ), } ) -# Add 14-day moving average to smooth noisy daily admissions -hosp_raw_180 = np.array(hosp_admits[:n_days_180], dtype=float) -hosp_ma_180 = pd.Series(hosp_raw_180).rolling(window=14, center=True).mean().values - -hosp_df_180 = pd.DataFrame( - { - "day": np.arange(n_days_180), - "median": hosp_ma_180, - "raw": hosp_raw_180, - "q05": np.nan, - "q95": np.nan, - "signal": "Hospital Admissions (14-day MA)", - } +period_summary_df = ( + ci_width_df.groupby("period", observed=True) + .agg( + start=("day", "min"), + end=("day", "max"), + mean_ci_width=("ci_width", "mean"), + ) + .reset_index() ) -plot_df_180 = pd.concat([infections_df_180, hosp_df_180], ignore_index=True) -plot_df_180["signal"] = pd.Categorical( - plot_df_180["signal"], - categories=["Hospital Admissions (14-day MA)", "Latent Infections"], - ordered=True, -) +print("\nMean 90% interval width by time period:") +for row in period_summary_df.itertuples(index=False): + print(f" {row.period}: {row.mean_ci_width:,.0f}") ( - p9.ggplot(plot_df_180, p9.aes(x="day")) - + p9.geom_ribbon( - p9.aes(ymin="q05", ymax="q95"), - fill="steelblue", - alpha=0.3, + p9.ggplot(ci_width_df, p9.aes(x="day", y="ci_width")) + + p9.geom_line(color="gray", alpha=0.6, size=0.7) + + p9.geom_segment( + data=period_summary_df, + mapping=p9.aes( + x="start", + xend="end", + y="mean_ci_width", + yend="mean_ci_width", + color="period", + ), + size=1.6, ) - + p9.geom_point( - p9.aes(y="raw"), - 
color="gray", - alpha=0.3, - size=1, - ) - + p9.geom_line( - p9.aes(y="median"), - color="darkblue", - size=1, - ) - + p9.facet_wrap("~signal", ncol=1, scales="free_y") - + p9.scale_y_log10() + p9.labs( x="Day", - y="Count", - title="Posterior Latent Infections vs Observed Hospitalizations (180 days)", + y="90% interval width", + color="Period", + title="Posterior uncertainty over the 90-day fit", ) + theme_tutorial ) ``` -### Comparing 90-Day vs 180-Day Fits - -Comparing the two fits reveals where uncertainty reduction occurs---and why it matters for forecasting. - -```{python} -#| label: compare-fits - -# Compare CI widths for the overlapping 90-day period -ci_width_90_overlap = quantiles_90["q95"] - quantiles_90["q05"] -ci_width_180_overlap = ( - quantiles_180["q95"][:n_days_90] - quantiles_180["q05"][:n_days_90] -) - -# Compute difference in CI widths -ci_diff = ci_width_90_overlap - ci_width_180_overlap -ci_ratio = ci_width_90_overlap / ci_width_180_overlap - -print("CI Width Comparison (first 90 days):") -print(f" 90-day fit mean CI width: {ci_width_90_overlap.mean():,.0f} infections") -print(f" 180-day fit mean CI width: {ci_width_180_overlap.mean():,.0f} infections") -print(f" Mean difference: {ci_diff.mean():,.0f} infections") -print(f" Mean ratio (90/180): {ci_ratio.mean():.2f}x") -print( - f"\nThe 180-day fit has {(1 - ci_width_180_overlap.mean() / ci_width_90_overlap.mean()) * 100:.1f}% narrower CIs on average" -) - -# Per-day comparison -print("\nCI width by time period:") -for start, end, label in [ - (0, 30, "Days 0-30"), - (30, 60, "Days 30-60"), - (60, 90, "Days 60-90"), -]: - mean_90 = ci_width_90_overlap[start:end].mean() - mean_180 = ci_width_180_overlap[start:end].mean() - reduction = (1 - mean_180 / mean_90) * 100 - print( - f" {label}: 90-day={mean_90:,.0f}, 180-day={mean_180:,.0f}, reduction={reduction:.1f}%" - ) -``` - -Notice that the uncertainty reduction is concentrated in days 60-90---the final month of the 90-day window. 
-Earlier periods (days 0-60) show little change because both fits have sufficient future data to constrain those estimates. - This pattern has a direct implication for forecasting: **renewal models are most uncertain at the edge of the observation window**. -Future observations constrain past latent infections through the renewal equation, but when predicting beyond available data, this constraint disappears. -The high uncertainty in days 60-90 of the 90-day fit is exactly what we'd expect when forecasting 30 days ahead---there's no future signal to anchor the estimates. +The daily line shows local changes in posterior uncertainty, while the monthly means make the broader edge effect easier to see. +Future observations constrain past latent infections through the renewal equation and observation delays, but the most recent days have less future signal available to anchor them. +That edge uncertainty is the same uncertainty that carries forward when the model is used for forecasting. ## Summary @@ -1221,7 +1056,7 @@ A few details are especially important when adapting this pattern: - `subpop_fractions` defines the latent subpopulations; `subpop_indices` tells an observation process which subpopulation each measurement observes. - The model time axis starts before the first observed day because PyRenew adds an initialization period for the renewal process. -The 90-day and 180-day fits also show why the end of the observation window matters. +The 90-day fit shows why the end of the observation window matters. Future observations help constrain past latent infections, so posterior uncertainty is usually largest near the most recent observed days. That same edge uncertainty is what carries forward when the model is used for forecasting. 
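As an aside to this patch: the new `monthly-CI` cell computes daily 90% interval widths and groups them into 30-day periods with `pd.cut`. A minimal standalone sketch of that same grouping logic, using synthetic posterior draws (`draws` is a hypothetical stand-in for the tutorial's latent-infection samples, not PyRenew output):

```python
# Sketch (not part of the patch): summarize posterior uncertainty by period,
# mirroring the pd.cut / groupby pattern in the monthly-CI cell.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n_days_90 = 90

# Hypothetical posterior draws: (n_draws, n_days), with spread growing
# toward the end of the window to mimic the edge effect discussed above.
scale = 1.0 + np.linspace(0.0, 2.0, n_days_90)
draws = rng.normal(loc=100.0, scale=scale, size=(1000, n_days_90))

# 90% interval width per day: 95th minus 5th percentile across draws.
q05 = np.quantile(draws, 0.05, axis=0)
q95 = np.quantile(draws, 0.95, axis=0)
ci_width_90 = q95 - q05

# Group days into 30-day periods, exactly as the patch does.
ci_width_df = pd.DataFrame(
    {
        "day": np.arange(n_days_90),
        "ci_width": ci_width_90,
        "period": pd.cut(
            np.arange(n_days_90),
            bins=[-1, 29, 59, 89],
            labels=["Days 0-29", "Days 30-59", "Days 60-89"],
        ),
    }
)
period_means = ci_width_df.groupby("period", observed=True)["ci_width"].mean()
print(period_means)
```

With the synthetic widening, the last period's mean interval width exceeds the first period's, which is the pattern the patch's plot is designed to surface.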
From 16f9c3d220e9d15a9a3e1b7475d1ba355f6ed222 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 12 May 2026 17:30:29 -0400 Subject: [PATCH 3/3] changes per code review --- .../tutorials/building_multisignal_models.qmd | 32 ++++++++----------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index a51367e7..7734d7c1 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -46,12 +46,6 @@ from datetime import date warnings.filterwarnings("ignore") from _tutorial_theme import theme_tutorial - - -def make_rng_key(): - """Generate a time-based random seed.""" - seed = int(time.time() * 1000) % (2**32) - return random.PRNGKey(seed) ``` ```{python} @@ -194,7 +188,7 @@ Latent infection processes implement the renewal equation to generate infection All latent processes share common components: - **Generation interval**: PMF for secondary infection timing -- **Initial infections (I0)**: Starting condition for the renewal process +- **Initial infections ($I(0)$)**: Starting condition for the renewal process - **Temporal dynamics**: How $\mathcal{R}(t)$ evolves over time ### Generation Interval @@ -208,9 +202,8 @@ covid_gen_int = [0.16, 0.32, 0.25, 0.14, 0.07, 0.04, 0.02] gen_int_pmf = jnp.array(covid_gen_int) gen_int_rv = DeterministicPMF("gen_int", gen_int_pmf) -# Mean generation time days = np.arange(len(gen_int_pmf)) -print(f"Generation interval length: {len(gen_int_pmf)} days") +print(f"Generation interval: {gen_int_pmf}") ``` ### I0: Initial Infections @@ -227,7 +220,7 @@ I0_rv = DistributionalVariable("I0", dist.Beta(1, 100)) ### Log Rt at time $0$ -We place a prior on the log(Rt) at time $0$, centered at 0.0 (Rt = 1.0) with moderate uncertainty: +We place a prior on the log $\mathcal{R}(t)$ at time $0$, centered at $0.0$ ($\mathcal{R}(t) = 1.0$) with moderate uncertainty: ```{python} #| label: 
log-rt-time-0 @@ -259,7 +252,7 @@ subpop_rt_deviation_process = RandomWalk( ) ``` -### Choosing the Rt Parameter Cadence +### Choosing the $\mathcal{R}(t)$ Parameter Cadence The renewal equation is evaluated on the model's daily time axis, but the temporal process for $\mathcal{R}(t)$ does not have to sample a new parameter every day. This separates three model choices: @@ -268,7 +261,7 @@ This separates three model choices: - **Model time axis**: the daily axis used by the renewal equation and delay convolutions - **Observation cadence**: the temporal granularity for each signal, such as daily ED visits or weekly hospital admissions -The DifferencedAR(1) process above samples one value per model day: +The `DifferencedAR1` process above samples one value per model day: ```python baseline_rt_process = DifferencedAR1( @@ -502,10 +495,10 @@ if ww_n_sites > 5: ``` Wastewater observations are site-level: each measurement is associated with a specific measurement site. -The Wastewater observation process uses LogNormalNoise, which takes hierarchical priors for the site-level mode and standard deviation parameters. +The `Wastewater` observation process uses `LogNormalNoise`, which takes hierarchical priors for the site-level mode and standard deviation parameters. This enables partial pooling across measurement sites. -Here we specify HierarchicalNormalPrior for the site-level mode and GammaGroupSdPrior for the standard deviation. +Here we specify `HierarchicalNormalPrior` for the site-level mode and `GammaGroupSdPrior` for the standard deviation. 
```{python} #| label: wastewater-obs-process @@ -640,7 +633,7 @@ For this model, a complete call has the following structure: model.run( num_warmup=500, num_samples=500, - rng_key=make_rng_key(), + rng_key=random.PRNGKey(42), mcmc_args={"num_chains": 4, "progress_bar": False}, n_days_post_init=n_days, population_size=population_size, @@ -784,7 +777,9 @@ The observation dictionary has the following structure: } ``` -At call time, this dictionary is unpacked with `**`, forwarding each entry as a keyword argument to `model.run()`. +At call time, we unpack this dictionary with `**`, forwarding each entry as a keyword argument to `model.run()`. +Note that we could also have manually passed each dictionary entry to `model.run()`. +We chose to create (and then unpack) an `obs_data` dictionary to keep our data organized in one place. ### MCMC Controls @@ -824,7 +819,7 @@ start_time = time.time() model.run( num_warmup=500, num_samples=500, - rng_key=make_rng_key(), + rng_key=random.PRNGKey(42), mcmc_args={"num_chains": 4, "progress_bar": False}, n_days_post_init=n_days_90, population_size=population_size, @@ -1020,11 +1015,12 @@ for row in period_summary_df.itertuples(index=False): ), size=1.6, ) + + p9.scale_y_log10() + p9.labs( x="Day", y="90% interval width", color="Period", - title="Posterior uncertainty over the 90-day fit", + title="Posterior uncertainty in latent infections over the 90-day fit", ) + theme_tutorial )
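A note on the `**obs_data` pattern this patch documents: unpacking the dictionary forwards each entry as a keyword argument whose name must match a registered observation process. A minimal sketch of the mechanics, where `fake_run` is a hypothetical stand-in for `model.run()` (not the real PyRenew signature):

```python
# Sketch (not part of the patch): how **-unpacking routes one data dict
# per observation process into the run call.
def fake_run(num_warmup, num_samples, hospital=None, wastewater=None):
    # Each registered observation process receives its own data dict,
    # keyed by the process's `name` attribute.
    return {"hospital": hospital, "wastewater": wastewater}

# One entry per observation process, keyed by process name.
obs_data = {
    "hospital": {"obs": [3, 5, 4]},
    "wastewater": {"obs": [1.2, 0.9], "times": [10, 12]},
}

# Unpacking with ** forwards each key as a keyword argument; this is
# equivalent to passing hospital=... and wastewater=... by hand.
result = fake_run(num_warmup=500, num_samples=500, **obs_data)
print(result["hospital"])
```

This is why the patch notes that the dictionary is optional: collecting the entries in `obs_data` only keeps the data in one place, and a mismatched key would fail the same way a misspelled keyword argument would.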