From 084c4e7e3cfddc4e1cde3f8c922bcc8eda0d20c5 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 21 Apr 2026 17:53:15 -0400 Subject: [PATCH 01/17] updated observation processes for aggregated count data; unit tests passing --- pyrenew/observation/base.py | 162 ++++++ pyrenew/observation/count_observations.py | 485 ++++++++++++++---- test/conftest.py | 182 +++++++ test/test_observation_counts.py | 593 +++++++++++++++++++++- test/test_observation_validation.py | 253 ++++++++- test/test_pyrenew_builder.py | 15 +- 6 files changed, 1556 insertions(+), 134 deletions(-) diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 3ddcfb1a..c1906aea 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -198,6 +198,113 @@ def _validate_dow_effect( if jnp.any(dow_effect < 0): raise ValueError(f"{param_name} must have non-negative values") + def _validate_aggregation_params( + self, + aggregation_period: int, + period_end_dow: int | None, + ) -> None: + """ + Validate temporal-aggregation constructor parameters. + + Checks that ``aggregation_period`` is an integer in + ``{1, 7}``, and that ``period_end_dow`` is an integer in + ``{0, ..., 6}`` (0=Monday, 6=Sunday, ISO convention) when + ``aggregation_period == 7``. ``period_end_dow`` is ignored + when ``aggregation_period == 1``. + + Parameters + ---------- + aggregation_period + Width of the reporting period in fundamental time units. + period_end_dow + Day-of-week index of each period's final day. + + Raises + ------ + ValueError + If ``aggregation_period`` is not in ``{1, 7}``, or if + ``period_end_dow`` is missing or out of range when + ``aggregation_period == 7``. + + Notes + ----- + ``period_end_dow`` is a weekly-specific anchor; generalizing + beyond ``aggregation_period == 7`` will require a different + anchor representation. + """ + if not isinstance(aggregation_period, int) or aggregation_period < 1: + raise ValueError( + "aggregation_period must be a positive integer, " + f"got {aggregation_period!r}" + ) + if aggregation_period not in (1, 7): + raise ValueError( + f"aggregation_period must be one of {{1, 7}}, got {aggregation_period}" + ) + if aggregation_period == 7: + if period_end_dow is None: + raise ValueError( + "period_end_dow is required when aggregation_period == 7" + ) + if not isinstance(period_end_dow, int) or not (0 <= period_end_dow <= 6): + raise ValueError( + "period_end_dow must be an integer in {0, ..., 6} " + f"(0=Monday, 6=Sunday), got {period_end_dow!r}" + ) + + def _compute_period_offset( + self, + first_day_dow: int | None, + aggregation_period: int, + period_end_dow: int | None, + ) -> int: + """ + Compute the number of leading daily timepoints to trim so + that the daily axis aligns to whole aggregation periods. + + Parameters + ---------- + first_day_dow + Day-of-week index of element 0 of the daily axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation_period == 7``. + aggregation_period + Width of the reporting period. Must be in ``{1, 7}``. + period_end_dow + Day-of-week index of each period's final day. Required + when ``aggregation_period == 7``. + + Returns + ------- + int + Trim offset in ``[0, aggregation_period)``. Returns + ``0`` when ``aggregation_period == 1``. + + Raises + ------ + ValueError + If ``aggregation_period`` is not in ``{1, 7}``, or if + ``first_day_dow`` or ``period_end_dow`` is ``None`` when + ``aggregation_period == 7``. + + Notes + ----- + For ``aggregation_period == 7``: + ``(period_end_dow + 1 - first_day_dow) % 7``. + """ + if aggregation_period == 1: + return 0 + if aggregation_period != 7: + raise ValueError( + f"aggregation_period must be one of {{1, 7}}, got {aggregation_period}" + ) + if first_day_dow is None or period_end_dow is None: + raise ValueError( + "first_day_dow and period_end_dow are both required " + "when aggregation_period == 7" + ) + return (period_end_dow + 1 - first_day_dow) % aggregation_period + def _convolve_with_alignment( self, latent_incidence: ArrayLike, @@ -363,6 +470,7 @@ def _validate_index_array( Validate an index array has non-negative values within bounds. Checks that all values are non-negative integers in ``[0, upper_bound)``. + An empty array is a no-op and passes validation. Parameters ---------- @@ -379,6 +487,8 @@ def _validate_index_array( If indices contains negative values or values >= upper_bound. """ indices = jnp.asarray(indices) + if indices.size == 0: + return if jnp.any(indices < 0): raise ValueError( f"Observation '{self.name}': {param_name} cannot be negative" @@ -484,6 +594,58 @@ def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: f"Pad with NaN for initialization period." ) + def _validate_period_end_times( + self, + period_end_times: ArrayLike, + n_total: int, + offset: int, + aggregation_period: int, + ) -> None: + """ + Validate a period-end-time index array. + + Checks that all values are non-negative, within + ``[0, n_total)``, and lie on aggregation-period boundaries, + i.e., ``(t - offset) % aggregation_period == + aggregation_period - 1`` for every entry ``t``. When + ``aggregation_period == 1`` the alignment condition holds + trivially and only the bounds check runs. + + Parameters + ---------- + period_end_times + Daily-axis indices of each observed period's final day. + n_total + Total number of time steps. + offset + Front-trim offset in daily units, as returned by + ``_compute_period_offset``. Must be in + ``[0, aggregation_period)``. + aggregation_period + Width of the reporting period in fundamental time units. + + Raises + ------ + ValueError + If ``period_end_times`` contains negative values, values + ``>= n_total``, or entries that fail the alignment check. + """ + self._validate_index_array(period_end_times, n_total, "period_end_times") + if aggregation_period == 1: + return + period_end_times = jnp.asarray(period_end_times) + misaligned = (period_end_times - offset) % aggregation_period != ( + aggregation_period - 1 + ) + if jnp.any(misaligned): + raise ValueError( + f"Observation '{self.name}': period_end_times must lie on " + f"aggregation-period boundaries " + f"(offset={offset}, aggregation_period={aggregation_period}); " + f"each entry t must satisfy " + f"(t - offset) % {aggregation_period} == {aggregation_period - 1}." + ) + @abstractmethod def sample( self, diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 563285c2..9519942e 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -14,7 +14,7 @@ from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise from pyrenew.observation.types import ObservationSample -from pyrenew.time import get_sequential_day_of_week_indices +from pyrenew.time import daily_to_weekly, get_sequential_day_of_week_indices class CountObservation(BaseObservationProcess): @@ -25,6 +25,8 @@ class CountObservation(BaseObservationProcess): with composable noise model. """ + _SUPPORTED_SCHEDULES = ("regular", "irregular") + def __init__( self, name: str, @@ -33,6 +35,9 @@ def __init__( noise: CountNoise, right_truncation_rv: RandomVariable | None = None, day_of_week_rv: RandomVariable | None = None, + aggregation_period: int = 1, + reporting_schedule: str = "regular", + period_end_dow: int | None = None, ) -> None: """ Initialize count observation base. @@ -63,12 +68,46 @@ def __init__( overall predicted counts. When provided (along with ``first_day_dow`` at sample time), predicted counts are scaled by a periodic weekly pattern. + aggregation_period + Width of the reporting period in fundamental time units. + Must be in ``{1, 7}``. ``1`` means no aggregation (daily + observations). + reporting_schedule + Either ``"regular"`` (dense observation array, one entry + per period, NaN for unobserved periods) or ``"irregular"`` + (sparse observation array with user-supplied period-end + time indices). + period_end_dow + Day-of-week index of each period's final day + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation_period == 7``; ignored otherwise. + + Raises + ------ + ValueError + If aggregation/reporting parameters are invalid, or if a + day-of-week effect is combined with ``aggregation_period > 1`` + (within-period structure is destroyed by aggregation). """ super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv) self.ascertainment_rate_rv = ascertainment_rate_rv self.noise = noise self.right_truncation_rv = right_truncation_rv self.day_of_week_rv = day_of_week_rv + self._validate_aggregation_params(aggregation_period, period_end_dow) + if reporting_schedule not in self._SUPPORTED_SCHEDULES: + raise ValueError( + f"reporting_schedule must be one of {self._SUPPORTED_SCHEDULES}, " + f"got {reporting_schedule!r}" + ) + if aggregation_period > 1 and day_of_week_rv is not None: + raise ValueError( + "day_of_week_rv cannot be combined with aggregation_period > 1; " + "aggregation destroys within-period structure." + ) + self.aggregation_period = aggregation_period + self.reporting_schedule = reporting_schedule + self.period_end_dow = period_end_dow def validate(self) -> None: """ @@ -244,6 +283,52 @@ def _apply_day_of_week( daily_effect = daily_effect[:, None] return predicted * daily_effect + def _aggregate( + self, + predicted_daily: ArrayLike, + first_day_dow: int | None, + ) -> ArrayLike: + """ + Aggregate daily predicted counts to the reporting-period grid. + + When ``aggregation_period == 1`` returns the input unchanged. + Otherwise sums daily values over non-overlapping fixed-width + periods anchored by ``period_end_dow``, via + ``pyrenew.time.daily_to_weekly``. Works on both 1D + ``(n_total,)`` and 2D ``(n_total, n_subpops)`` inputs. + + Parameters + ---------- + predicted_daily + Predicted counts on the daily time axis. + first_day_dow + Day-of-week index of element 0 of the daily axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation_period > 1``. + + Returns + ------- + ArrayLike + Aggregated counts on the period grid; same trailing + dimensions as ``predicted_daily``. Returns + ``predicted_daily`` unchanged when + ``aggregation_period == 1``. + + Raises + ------ + ValueError + If ``aggregation_period > 1`` and ``first_day_dow`` is ``None``. + """ + if self.aggregation_period == 1: + return predicted_daily + if first_day_dow is None: + raise ValueError("first_day_dow is required when aggregation_period > 1") + return daily_to_weekly( + predicted_daily, + input_data_first_dow=first_day_dow, + week_start_dow=(self.period_end_dow + 1) % 7, + ) + class PopulationCounts(CountObservation): """ @@ -293,6 +378,8 @@ def validate_data( n_total: int, n_subpops: int, obs: ArrayLike | None = None, + period_end_times: ArrayLike | None = None, + first_day_dow: int | None = None, **kwargs: object, ) -> None: """ @@ -301,21 +388,66 @@ def validate_data( Parameters ---------- n_total - Total number of time steps (n_init + n_days_post_init). + Total number of daily time steps (``n_init + n_days_post_init``). n_subpops Number of subpopulations (unused for aggregate observations). obs - Observed counts on shared time axis. Shape: (n_total,). + Observed counts. Shape depends on ``reporting_schedule``: + ``"regular"`` expects a dense array of length ``n_total // P`` + after front-trim, with NaN for unobserved periods; + ``"irregular"`` expects an array matching ``period_end_times``. + period_end_times + Daily-axis indices of each observed period's final day. Required + for ``reporting_schedule="irregular"``. + first_day_dow + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation_period > 1``. **kwargs Additional keyword arguments (ignored). Raises ------ ValueError - If obs length doesn't match n_total. + If obs length or period_end_times fail their respective + checks, or if ``first_day_dow`` is missing when + ``aggregation_period > 1``. """ + P = self.aggregation_period + + if self.reporting_schedule == "regular": + if obs is None: + return + if P == 1: + self._validate_obs_dense(obs, n_total) + return + if first_day_dow is None: + raise ValueError( + f"Observation '{self.name}': first_day_dow is required " + f"when aggregation_period == {P}" + ) + offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) + n_periods = (n_total - offset) // P + obs = jnp.asarray(obs) + if obs.shape[0] != n_periods: + raise ValueError( + f"Observation '{self.name}': obs length {obs.shape[0]} " + f"must equal n_periods ({n_periods}). " + f"Pad with NaN for unobserved periods." + ) + return + + if period_end_times is None: + return + if P > 1 and first_day_dow is None: + raise ValueError( + f"Observation '{self.name}': first_day_dow is required " + f"when aggregation_period == {P}" + ) + offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) + self._validate_period_end_times(period_end_times, n_total, offset, P) if obs is not None: - self._validate_obs_dense(obs, n_total) + self._validate_obs_times_shape(obs, period_end_times) def sample( self, @@ -323,79 +455,110 @@ def sample( obs: ArrayLike | None = None, right_truncation_offset: int | None = None, first_day_dow: int | None = None, + period_end_times: ArrayLike | None = None, ) -> ObservationSample: """ Sample aggregated counts. - Both infections and obs use a shared time axis [0, n_total) where - n_total = n_init + n_days. NaN in obs marks unobserved timepoints - (initialization period or missing data). + Daily transforms (right-truncation, day-of-week) run on the + daily axis. When ``aggregation_period > 1`` the daily + predictions are summed onto the reporting-period grid before + the noise model. Likelihood path depends on + ``reporting_schedule``: ``"regular"`` uses a dense-with-NaN + array plus a mask; ``"irregular"`` fancy-indexes the + aggregated array at period indices derived from + ``period_end_times``. Parameters ---------- infections Aggregate infections from the infection process. - Shape: (n_total,) where n_total = n_init + n_days. + Shape ``(n_total,)``. obs - Observed counts on shared time axis. Shape: (n_total,). - Use NaN for initialization period and any missing observations. - None for prior predictive sampling. + Observed counts. Shape depends on ``reporting_schedule``: + ``"regular"`` expects a dense array on the period grid + with NaN for unobserved periods; ``"irregular"`` expects + an array of the same length as ``period_end_times``. + ``None`` for prior predictive sampling. right_truncation_offset - If provided (and ``right_truncation_rv`` was set at construction), - apply right-truncation adjustment to predicted counts. - first_day_dow : int | None - Day of the week for the first timepoint on the shared time - axis (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Use - ``model.compute_first_day_dow(obs_start_dow)`` to convert - from the day of the week of the first observation. + If provided (and ``right_truncation_rv`` was set at + construction), apply right-truncation adjustment to the + daily predictions. + first_day_dow + Day-of-week index of the first timepoint on the shared + time axis (0=Monday, 6=Sunday, ISO convention). Required + when ``day_of_week_rv`` was set at construction or when + ``aggregation_period > 1``. + period_end_times + Daily-axis indices of each observed period's final day. + Required when ``reporting_schedule == "irregular"``. Returns ------- ObservationSample - Named tuple with `observed` (sampled/conditioned counts) and - `predicted` (predicted counts before noise, shape: n_total). + Named tuple with ``observed`` (sampled/conditioned counts) + and ``predicted`` (predictions on the reporting-period + grid; equal to daily predictions when + ``aggregation_period == 1``). """ - predicted_counts = self._predicted_obs(infections) + predicted_daily = self._predicted_obs(infections) if self.day_of_week_rv is not None: if first_day_dow is None: raise ValueError( "first_day_dow is required when day_of_week_rv is set." ) - predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow) + predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) if self.right_truncation_rv is not None and right_truncation_offset is not None: - predicted_counts = self._apply_right_truncation( - predicted_counts, right_truncation_offset + predicted_daily = self._apply_right_truncation( + predicted_daily, right_truncation_offset ) - self._deterministic("predicted", predicted_counts) - # Compute mask: True where observation contributes to likelihood. - # NaN in predictions (initialization period) or obs (missing data) - # are excluded via mask. - valid_pred = ~jnp.isnan(predicted_counts) - if obs is not None: - valid_obs = ~jnp.isnan(obs) - mask = valid_pred & valid_obs - else: - mask = valid_pred - - # JAX evaluates log_prob for all array elements even when mask - # excludes them from the likelihood sum. Replace NaN with safe values - # to avoid NaN propagation in JAX's computation graph. These values - # do not affect inference since mask=False excludes them. - safe_predicted = jnp.where(jnp.isnan(predicted_counts), 1.0, predicted_counts) - safe_obs = None - if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + predicted = self._aggregate(predicted_daily, first_day_dow) + if self.aggregation_period > 1: + self._deterministic("predicted_daily", predicted_daily) + self._deterministic("predicted", predicted) - observed = self.noise.sample( - name=self._sample_site_name("obs"), - predicted=safe_predicted, - obs=safe_obs, - mask=mask, - ) + if self.reporting_schedule == "regular": + valid_pred = ~jnp.isnan(predicted) + if obs is not None: + valid_obs = ~jnp.isnan(obs) + mask = valid_pred & valid_obs + else: + mask = valid_pred + + # JAX evaluates log_prob for all elements even when mask excludes + # them from the likelihood sum. Replace NaN with safe values to + # avoid NaN propagation; mask=False ensures they do not affect + # inference. + safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) + safe_obs = None + if obs is not None: + safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=safe_predicted, + obs=safe_obs, + mask=mask, + ) + else: + if period_end_times is None: + raise ValueError( + f"Observation '{self.name}': period_end_times is " + f"required when reporting_schedule == 'irregular'" + ) + P = self.aggregation_period + offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) + period_idx = (jnp.asarray(period_end_times) - offset - (P - 1)) // P + predicted_obs = predicted[period_idx] + + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=predicted_obs, + obs=obs, + ) - return ObservationSample(observed=observed, predicted=predicted_counts) + return ObservationSample(observed=observed, predicted=predicted) class SubpopulationCounts(CountObservation): @@ -444,9 +607,10 @@ def validate_data( self, n_total: int, n_subpops: int, - times: ArrayLike | None = None, - subpop_indices: ArrayLike | None = None, obs: ArrayLike | None = None, + period_end_times: ArrayLike | None = None, + first_day_dow: int | None = None, + subpop_indices: ArrayLike | None = None, **kwargs: object, ) -> None: """ @@ -455,96 +619,213 @@ def validate_data( Parameters ---------- n_total - Total number of time steps (n_init + n_days_post_init). + Total number of daily time steps (``n_init + n_days_post_init``). n_subpops Number of subpopulations. - times - Day index for each observation on the shared time axis. - subpop_indices - Subpopulation index for each observation (0-indexed). obs - Observed counts (n_obs,). + Observed counts. For ``reporting_schedule="regular"`` + has shape ``(n_periods, n_observed_subpops)`` with NaN + for unobserved periods. For + ``reporting_schedule="irregular"`` has shape ``(n_obs,)`` + matching ``period_end_times`` and ``subpop_indices``. + period_end_times + Daily-axis indices of each observed period's final day. + Required for ``reporting_schedule="irregular"``. + first_day_dow + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation_period > 1``. + subpop_indices + Subpopulation indices (0-indexed). For + ``reporting_schedule="regular"``: shape + ``(n_observed_subpops,)`` selecting which subpopulation + columns appear in ``obs``. For + ``reporting_schedule="irregular"``: shape ``(n_obs,)`` + with one subpopulation per observation. **kwargs Additional keyword arguments (ignored). Raises ------ ValueError - If times or subpop_indices are out of bounds, or if - obs and times have mismatched lengths. + If any index array is out of bounds, any shape check + fails, or ``first_day_dow`` is missing when + ``aggregation_period > 1``. """ - if times is not None: - self._validate_times(times, n_total) - if obs is not None: - self._validate_obs_times_shape(obs, times) + P = self.aggregation_period + if subpop_indices is not None: self._validate_subpop_indices(subpop_indices, n_subpops) + if self.reporting_schedule == "regular": + if obs is None: + return + if P > 1 and first_day_dow is None: + raise ValueError( + f"Observation '{self.name}': first_day_dow is required " + f"when aggregation_period == {P}" + ) + offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) + n_periods = n_total if P == 1 else (n_total - offset) // P + obs = jnp.asarray(obs) + if obs.ndim != 2: + raise ValueError( + f"Observation '{self.name}': regular-schedule obs must " + f"be 2D (n_periods, n_observed_subpops); got shape {obs.shape}" + ) + if obs.shape[0] != n_periods: + raise ValueError( + f"Observation '{self.name}': obs dimension 0 length " + f"{obs.shape[0]} must equal n_periods ({n_periods}). " + f"Pad with NaN for unobserved periods." + ) + if subpop_indices is not None: + n_observed = jnp.asarray(subpop_indices).shape[0] + if obs.shape[1] != n_observed: + raise ValueError( + f"Observation '{self.name}': obs dimension 1 length " + f"{obs.shape[1]} must equal len(subpop_indices) " + f"({n_observed})" + ) + return + + if period_end_times is None: + return + if P > 1 and first_day_dow is None: + raise ValueError( + f"Observation '{self.name}': first_day_dow is required " + f"when aggregation_period == {P}" + ) + offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) + self._validate_period_end_times(period_end_times, n_total, offset, P) + if obs is not None: + self._validate_obs_times_shape(obs, period_end_times) + if subpop_indices is not None: + self._validate_obs_times_shape(subpop_indices, period_end_times) + def sample( self, infections: ArrayLike, - times: ArrayLike, - subpop_indices: ArrayLike, obs: ArrayLike | None = None, right_truncation_offset: int | None = None, first_day_dow: int | None = None, + period_end_times: ArrayLike | None = None, + subpop_indices: ArrayLike | None = None, ) -> ObservationSample: """ Sample subpopulation-level counts. - Times are on the shared time axis [0, n_total) where - n_total = n_init + n_days. This method performs direct indexing - without any offset adjustment. + Daily transforms (right-truncation, day-of-week) run on the + daily axis. When ``aggregation_period > 1`` the daily + predictions are summed onto the reporting-period grid before + the noise model. Likelihood path depends on + ``reporting_schedule``: ``"regular"`` selects the observed + subpopulation columns and uses a dense-with-NaN array plus + a mask; ``"irregular"`` fancy-indexes the aggregated array + at period indices derived from ``period_end_times``. Parameters ---------- infections Subpopulation-level infections from the infection process. - Shape: (n_total, n_subpops) - times - Day index for each observation on the shared time axis. - Must be in range [0, n_total). Shape: (n_obs,) - subpop_indices - Subpopulation index for each observation (0-indexed). - Shape: (n_obs,) + Shape ``(n_total, n_subpops)``. obs - Observed counts (n_obs,), or None for prior sampling. + Observed counts. For ``reporting_schedule="regular"``: + shape ``(n_periods, n_observed_subpops)`` with NaN for + unobserved periods. For + ``reporting_schedule="irregular"``: shape ``(n_obs,)`` + matching ``period_end_times`` and ``subpop_indices``. + ``None`` for prior predictive sampling. right_truncation_offset - If provided (and ``right_truncation_rv`` was set at construction), - apply right-truncation adjustment to predicted counts. - first_day_dow : int | None - Day of the week for the first timepoint on the shared time - axis (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Use - ``model.compute_first_day_dow(obs_start_dow)`` to convert - from the day of the week of the first observation. + If provided (and ``right_truncation_rv`` was set at + construction), apply right-truncation adjustment to the + daily predictions. + first_day_dow + Day-of-week index of the first timepoint on the shared + time axis (0=Monday, 6=Sunday, ISO convention). Required + when ``day_of_week_rv`` was set at construction or when + ``aggregation_period > 1``. + period_end_times + Daily-axis indices of each observed period's final day. + Required when ``reporting_schedule == "irregular"``. + subpop_indices + Subpopulation indices (0-indexed). Required. For + ``reporting_schedule="regular"``: shape + ``(n_observed_subpops,)`` selecting which subpopulation + columns of the aggregated array enter the likelihood. + For ``reporting_schedule="irregular"``: shape ``(n_obs,)`` + with one subpopulation per observation. Returns ------- ObservationSample - Named tuple with `observed` (sampled/conditioned counts) and - `predicted` (predicted counts before noise, shape: n_total x n_subpops). + Named tuple with ``observed`` (sampled/conditioned counts) + and ``predicted`` (predictions on the reporting-period + grid, shape ``(n_periods, n_subpops)``; equal to daily + predictions when ``aggregation_period == 1``). """ - predicted_counts = self._predicted_obs(infections) + if subpop_indices is None: + raise ValueError(f"Observation '{self.name}': subpop_indices is required.") + + predicted_daily = self._predicted_obs(infections) if self.day_of_week_rv is not None: if first_day_dow is None: raise ValueError( "first_day_dow is required when day_of_week_rv is set." ) - predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow) + predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) if self.right_truncation_rv is not None and right_truncation_offset is not None: - predicted_counts = self._apply_right_truncation( - predicted_counts, right_truncation_offset + predicted_daily = self._apply_right_truncation( + predicted_daily, right_truncation_offset ) - self._deterministic("predicted", predicted_counts) - # Direct indexing on shared time axis - no offset needed - predicted_obs = predicted_counts[times, subpop_indices] + predicted = self._aggregate(predicted_daily, first_day_dow) + if self.aggregation_period > 1: + self._deterministic("predicted_daily", predicted_daily) + self._deterministic("predicted", predicted) - observed = self.noise.sample( - name=self._sample_site_name("obs"), - predicted=predicted_obs, - obs=obs, - ) + if self.reporting_schedule == "regular": + predicted_selected = predicted[:, subpop_indices] + + valid_pred = ~jnp.isnan(predicted_selected) + if obs is not None: + valid_obs = ~jnp.isnan(obs) + mask = valid_pred & valid_obs + else: + mask = valid_pred + + # JAX evaluates log_prob for all elements even when mask excludes + # them from the likelihood sum. Replace NaN with safe values to + # avoid NaN propagation; mask=False ensures they do not affect + # inference. + safe_predicted = jnp.where( + jnp.isnan(predicted_selected), 1.0, predicted_selected + ) + safe_obs = None + if obs is not None: + safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=safe_predicted, + obs=safe_obs, + mask=mask, + ) + else: + if period_end_times is None: + raise ValueError( + f"Observation '{self.name}': period_end_times is " + f"required when reporting_schedule == 'irregular'" + ) + P = self.aggregation_period + offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) + period_idx = (jnp.asarray(period_end_times) - offset - (P - 1)) // P + predicted_obs = predicted[period_idx, subpop_indices] + + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=predicted_obs, + obs=obs, + ) - return ObservationSample(observed=observed, predicted=predicted_counts) + return ObservationSample(observed=observed, predicted=predicted) diff --git a/test/conftest.py b/test/conftest.py index c6dfb4c3..27f54f8f 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -19,7 +19,9 @@ from pyrenew.observation import ( HierarchicalNormalNoise, NegativeBinomialNoise, + PoissonNoise, PopulationCounts, + SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable @@ -260,6 +262,186 @@ def counts_factory(): return CountsProcessFactory() +@pytest.fixture +def weekly_regular_counts(simple_delay_pmf): + """ + PopulationCounts with weekly aggregation and regular (dense) reporting. + + Reporting periods end on Saturdays (``period_end_dow=5``), matching the + MMWR epiweek convention. + + Returns + ------- + PopulationCounts + A weekly-regular PopulationCounts process. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation_period=7, + reporting_schedule="regular", + period_end_dow=5, + ) + + +@pytest.fixture +def weekly_irregular_counts(simple_delay_pmf): + """ + PopulationCounts with weekly aggregation and irregular (sparse) reporting. + + Reporting periods end on Saturdays (``period_end_dow=5``), matching the + MMWR epiweek convention. + + Returns + ------- + PopulationCounts + A weekly-irregular PopulationCounts process. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation_period=7, + reporting_schedule="irregular", + period_end_dow=5, + ) + + +@pytest.fixture +def daily_irregular_counts(simple_delay_pmf): + """ + PopulationCounts with daily scale and irregular (sparse) reporting. + + Returns + ------- + PopulationCounts + A daily-irregular PopulationCounts process. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + reporting_schedule="irregular", + ) + + +@pytest.fixture +def weekly_regular_subpop_counts(simple_delay_pmf): + """ + SubpopulationCounts with weekly aggregation and regular (dense) reporting. + + Reporting periods end on Saturdays (``period_end_dow=5``), matching the + MMWR epiweek convention. + + Returns + ------- + SubpopulationCounts + A weekly-regular SubpopulationCounts process. + """ + return SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation_period=7, + reporting_schedule="regular", + period_end_dow=5, + ) + + +@pytest.fixture +def weekly_irregular_subpop_counts(simple_delay_pmf): + """ + SubpopulationCounts with weekly aggregation and irregular (sparse) reporting. + + Reporting periods end on Saturdays (``period_end_dow=5``), matching the + MMWR epiweek convention. + + Returns + ------- + SubpopulationCounts + A weekly-irregular SubpopulationCounts process. + """ + return SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation_period=7, + reporting_schedule="irregular", + period_end_dow=5, + ) + + +@pytest.fixture +def daily_irregular_subpop_counts(simple_delay_pmf): + """ + SubpopulationCounts with daily scale and irregular (sparse) reporting. + + Returns + ------- + SubpopulationCounts + A daily-irregular SubpopulationCounts process. + """ + return SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + reporting_schedule="irregular", + ) + + # ============================================================================= # Infection Fixtures # ============================================================================= + + +@pytest.fixture +def subpop_infections_28d(): + """ + Four-week ``(28, 3)`` infections array at 100/day/subpop. + + Used by weekly-aggregation tests on ``SubpopulationCounts``. + + Returns + ------- + jnp.ndarray + Shape ``(28, 3)`` filled with 100.0. + """ + return jnp.ones((28, 3)) * 100.0 + + +@pytest.fixture +def subpop_infections_30d(): + """ + Thirty-day ``(30, 3)`` infections array at 100/day/subpop. + + Used by daily-scale tests on ``SubpopulationCounts``. + + Returns + ------- + jnp.ndarray + Shape ``(30, 3)`` filled with 100.0. + """ + return jnp.ones((30, 3)) * 100.0 + + +@pytest.fixture +def mmwr_saturday_indices_first_three(): + """ + Daily-axis indices of the first three MMWR Saturdays for a Sunday-start axis. + + When element 0 of the daily axis is a Sunday (``first_day_dow=6``), + Saturdays fall on indices 6, 13, 20, ... + + Returns + ------- + jnp.ndarray + Shape ``(3,)`` containing ``[6, 13, 20]``. + """ + return jnp.array([6, 13, 20]) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 82bea1a3..e072c9b4 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -283,6 +283,7 @@ def test_non_contiguous_subpop_indices(self): ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), + reporting_schedule="irregular", ) # 5 subpopulations with distinct infection levels @@ -291,13 +292,13 @@ def test_non_contiguous_subpop_indices(self): for k in range(5): infections = infections.at[:, k].set((k + 1) * 100.0) - times = jnp.array([10, 10, 10]) + period_end_times = jnp.array([10, 10, 10]) subpop_indices = jnp.array([0, 2, 4]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, ) @@ -550,18 +551,19 @@ def test_counts_by_subpop_2d_broadcasting(self): delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), right_truncation_rv=DeterministicPMF("rt_delay", rt_pmf), + reporting_schedule="irregular", ) n_days = 10 n_subpops = 3 infections = jnp.ones((n_days, n_subpops)) * 100 - times = jnp.array([0, 8, 9]) + period_end_times = jnp.array([0, 8, 9]) subpop_indices = jnp.array([0, 1, 2]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, right_truncation_offset=0, @@ -739,18 +741,19 @@ def test_counts_by_subpop_2d_broadcasting(self): delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), day_of_week_rv=DeterministicVariable("dow", dow_effect), + reporting_schedule="irregular", ) n_days = 14 n_subpops = 3 infections = jnp.ones((n_days, n_subpops)) * 100 - times = jnp.array([0, 1, 7]) + period_end_times = jnp.array([0, 1, 7]) subpop_indices = jnp.array([0, 1, 2]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, first_day_dow=0, @@ -843,21 +846,595 @@ def test_counts_by_subpop_dow_without_offset_raises(self): delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), noise=PoissonNoise(), day_of_week_rv=DeterministicVariable("dow", dow_effect), + reporting_schedule="irregular", ) infections = jnp.ones((14, 2)) * 100 - times = jnp.array([0, 1]) + period_end_times = jnp.array([0, 1]) subpop_indices = jnp.array([0, 1]) with numpyro.handlers.seed(rng_seed=42): with pytest.raises(ValueError, match="first_day_dow is required"): process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, first_day_dow=None, ) +# =================================================================== +# PopulationCounts with aggregation: construction-time validation +# =================================================================== + + +class TestPopulationCountsAggregationConstruction: + """Construction-time validation for the aggregation parameters.""" + + def _make(self, simple_delay_pmf, **kwargs): + """ + Build a PopulationCounts with a stub noise model and optional overrides. + + Returns + ------- + PopulationCounts + A PopulationCounts instance using the supplied delay PMF and + any additional constructor overrides passed via ``**kwargs``. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + **kwargs, + ) + + def test_default_construction_is_daily_regular(self, simple_delay_pmf): + """Default constructor yields aggregation_period=1, reporting_schedule='regular'.""" + process = self._make(simple_delay_pmf) + assert process.aggregation_period == 1 + assert process.reporting_schedule == "regular" + assert process.period_end_dow is None + + def test_weekly_requires_period_end_dow(self, simple_delay_pmf): + """aggregation_period=7 without period_end_dow must raise.""" + with pytest.raises(ValueError, match="period_end_dow is required"): + self._make(simple_delay_pmf, aggregation_period=7) + + def test_weekly_with_saturday_anchor_constructs(self, simple_delay_pmf): + """aggregation_period=7, period_end_dow=5 (Saturday) is valid.""" + process = self._make(simple_delay_pmf, aggregation_period=7, period_end_dow=5) + assert process.aggregation_period == 7 + assert process.period_end_dow == 5 + + def test_dow_effect_with_weekly_aggregation_raises(self, simple_delay_pmf): + """day_of_week_rv cannot be combined with aggregation_period > 1.""" + with pytest.raises(ValueError, match="day_of_week_rv cannot be combined"): + self._make( + simple_delay_pmf, + aggregation_period=7, + period_end_dow=5, + day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), + ) + + def test_dow_effect_with_daily_aggregation_allowed(self, simple_delay_pmf): + """day_of_week_rv remains valid for aggregation_period=1.""" + process = self._make( + simple_delay_pmf, + day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), + ) + assert process.day_of_week_rv is not None + + def test_unknown_reporting_schedule_raises(self, simple_delay_pmf): + """reporting_schedule must be 'regular' or 'irregular'.""" + with pytest.raises(ValueError, match="reporting_schedule must be one of"): + self._make(simple_delay_pmf, reporting_schedule="sporadic") + + +# =================================================================== +# PopulationCounts with aggregation: validate_data +# =================================================================== + + +class TestPopulationCountsAggregationValidateData: + """validate_data branches for each (schedule, aggregation_period) combination.""" + + def test_weekly_regular_correct_length_passes(self, weekly_regular_counts): + """Weekly-regular obs of length n_periods passes when first_day_dow aligns.""" + obs = jnp.ones(4) * 10.0 + weekly_regular_counts.validate_data( + n_total=28, n_subpops=1, obs=obs, first_day_dow=6 + ) + + def test_weekly_regular_wrong_length_raises(self, weekly_regular_counts): + """Weekly-regular obs with wrong length raises.""" + obs = jnp.ones(28) * 10.0 + with pytest.raises(ValueError, match="must equal n_periods"): + weekly_regular_counts.validate_data( + n_total=28, n_subpops=1, obs=obs, first_day_dow=6 + ) + + def test_weekly_regular_missing_first_day_dow_raises(self, weekly_regular_counts): + """Weekly-regular with first_day_dow=None raises.""" + obs = jnp.ones(4) * 10.0 + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_regular_counts.validate_data(n_total=28, n_subpops=1, obs=obs) + + def test_weekly_regular_obs_none_passes(self, weekly_regular_counts): + """Weekly-regular with obs=None skips checks.""" + weekly_regular_counts.validate_data(n_total=28, n_subpops=1) + + def test_weekly_regular_honors_offset(self, weekly_regular_counts): + """Weekly-regular n_periods reflects front-trim offset.""" + obs = jnp.ones(3) * 10.0 + weekly_regular_counts.validate_data( + n_total=28, n_subpops=1, obs=obs, first_day_dow=0 + ) + + def test_weekly_irregular_aligned_times_pass(self, weekly_irregular_counts): + """Weekly-irregular with Saturdays at offset 0 passes.""" + period_end_times = jnp.array([6, 13, 20]) + obs = jnp.ones(3) * 10.0 + weekly_irregular_counts.validate_data( + n_total=28, + n_subpops=1, + obs=obs, + period_end_times=period_end_times, + first_day_dow=6, + ) + + def test_weekly_irregular_misaligned_times_raise(self, weekly_irregular_counts): + """Weekly-irregular with non-Saturday period_end_times raises.""" + period_end_times = jnp.array([5, 13, 20]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + weekly_irregular_counts.validate_data( + n_total=28, + n_subpops=1, + period_end_times=period_end_times, + first_day_dow=6, + ) + + def test_weekly_irregular_missing_first_day_dow_raises( + self, weekly_irregular_counts + ): + """Weekly-irregular with first_day_dow=None raises.""" + period_end_times = jnp.array([6, 13, 20]) + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_irregular_counts.validate_data( + n_total=28, n_subpops=1, period_end_times=period_end_times + ) + + def test_weekly_irregular_obs_shape_mismatch_raises(self, weekly_irregular_counts): + """Weekly-irregular obs of wrong length raises.""" + period_end_times = jnp.array([6, 13, 20]) + obs = jnp.ones(2) * 10.0 + with pytest.raises(ValueError, match="must match"): + weekly_irregular_counts.validate_data( + n_total=28, + n_subpops=1, + obs=obs, + period_end_times=period_end_times, + first_day_dow=6, + ) + + def test_daily_irregular_passes(self, daily_irregular_counts): + """Daily-irregular validates via bounds only; alignment is trivial.""" + period_end_times = jnp.array([0, 5, 19]) + obs = jnp.ones(3) * 10.0 + daily_irregular_counts.validate_data( + n_total=20, + n_subpops=1, + obs=obs, + period_end_times=period_end_times, + ) + + def test_daily_irregular_out_of_bounds_raises(self, daily_irregular_counts): + """Daily-irregular out-of-bounds index raises.""" + period_end_times = jnp.array([0, 5, 25]) + with pytest.raises(ValueError, match="upper bound"): + daily_irregular_counts.validate_data( + n_total=20, n_subpops=1, period_end_times=period_end_times + ) + + +# =================================================================== +# PopulationCounts with aggregation: sample +# =================================================================== + + +class TestPopulationCountsAggregationSample: + """Sample-time behavior for the new aggregation paths.""" + + def test_weekly_regular_predicted_shape(self, weekly_regular_counts): + """Weekly-regular predicted has shape (n_periods,) equal to weekly sums.""" + infections = jnp.ones(28) * 100.0 + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_counts.sample( + infections=infections, first_day_dow=6 + ) + assert result.predicted.shape == (4,) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + + def test_weekly_regular_emits_predicted_daily_site(self, weekly_regular_counts): + """When aggregation_period > 1, the 'predicted_daily' deterministic site exists.""" + infections = jnp.ones(28) * 100.0 + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + weekly_regular_counts.sample(infections=infections, first_day_dow=6) + assert "hosp_predicted_daily" in trace + assert "hosp_predicted" in trace + assert trace["hosp_predicted_daily"]["value"].shape == (28,) + assert trace["hosp_predicted"]["value"].shape == (4,) + + def test_daily_backward_compat_no_predicted_daily_site(self, simple_delay_pmf): + """When aggregation_period == 1, 'predicted_daily' is not emitted.""" + process = PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + infections = jnp.ones(28) * 100.0 + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + process.sample(infections=infections) + assert "hosp_predicted_daily" not in trace + assert "hosp_predicted" in trace + assert trace["hosp_predicted"]["value"].shape == (28,) + + def test_weekly_regular_with_obs_runs(self, weekly_regular_counts): + """Weekly-regular sample accepts dense-with-NaN obs on the period grid.""" + infections = jnp.ones(28) * 100.0 + obs = jnp.array([7.0, jnp.nan, 7.0, 7.0]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_counts.sample( + infections=infections, obs=obs, first_day_dow=6 + ) + assert result.observed.shape == (4,) + assert result.predicted.shape == (4,) + + def test_weekly_irregular_period_indexing(self, weekly_irregular_counts): + """Weekly-irregular fancy-indexes the aggregated array at correct periods.""" + infections = jnp.ones(28) * 100.0 + period_end_times = jnp.array([6, 20]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_irregular_counts.sample( + infections=infections, + period_end_times=period_end_times, + first_day_dow=6, + ) + assert result.predicted.shape == (4,) + assert result.observed.shape == (2,) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + + def test_weekly_irregular_missing_period_end_times_raises( + self, weekly_irregular_counts + ): + """Irregular schedule requires period_end_times at sample time.""" + infections = jnp.ones(28) * 100.0 + with pytest.raises(ValueError, match="period_end_times is required"): + with numpyro.handlers.seed(rng_seed=42): + weekly_irregular_counts.sample(infections=infections, first_day_dow=6) + + def test_daily_irregular_period_indexing(self, daily_irregular_counts): + """Daily-irregular fancy-indexes at the supplied daily indices directly.""" + infections = jnp.ones(30) * 100.0 + period_end_times = jnp.array([5, 10, 20]) + with numpyro.handlers.seed(rng_seed=42): + result = daily_irregular_counts.sample( + infections=infections, period_end_times=period_end_times + ) + assert result.predicted.shape == (30,) + assert result.observed.shape == (3,) + + +# =================================================================== +# SubpopulationCounts with aggregation: validate_data +# =================================================================== + + +class TestSubpopulationCountsAggregationValidateData: + """validate_data branches for each (schedule, aggregation_period) combination.""" + + def test_daily_regular_valid_passes(self, simple_delay_pmf): + """Daily-regular dense 2D obs (n_total, n_observed_subpops) passes.""" + process = SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + obs = jnp.ones((30, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + process.validate_data( + n_total=30, n_subpops=3, obs=obs, subpop_indices=subpop_indices + ) + + def test_weekly_regular_valid_passes(self, weekly_regular_subpop_counts): + """Weekly-regular dense 2D obs (n_periods, n_observed_subpops) passes.""" + obs = jnp.ones((4, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_regular_wrong_n_periods_raises(self, weekly_regular_subpop_counts): + """Weekly-regular obs with wrong dim-0 length raises.""" + obs = jnp.ones((28, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + with pytest.raises(ValueError, match="must equal n_periods"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_regular_wrong_n_subpops_raises(self, weekly_regular_subpop_counts): + """Regular-schedule obs dim-1 must equal len(subpop_indices).""" + obs = jnp.ones((4, 3)) * 5.0 + subpop_indices = jnp.array([0, 2]) + with pytest.raises(ValueError, match=r"must equal len\(subpop_indices\)"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_regular_missing_first_day_dow_raises( + self, weekly_regular_subpop_counts + ): + """Weekly-regular without first_day_dow raises.""" + obs = jnp.ones((4, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + subpop_indices=subpop_indices, + ) + + def test_regular_1d_obs_raises(self, weekly_regular_subpop_counts): + """Regular-schedule obs must be 2D.""" + obs = jnp.ones(4) * 5.0 + subpop_indices = jnp.array([0]) + with pytest.raises(ValueError, match="regular-schedule obs must be 2D"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_regular_bad_subpop_indices_raises(self, weekly_regular_subpop_counts): + """Regular-schedule out-of-bounds subpop_indices raises.""" + subpop_indices = jnp.array([0, 5]) + with pytest.raises(ValueError, match="upper bound"): + weekly_regular_subpop_counts.validate_data( + n_total=28, n_subpops=3, subpop_indices=subpop_indices + ) + + def test_daily_irregular_valid_passes(self, daily_irregular_subpop_counts): + """Daily-irregular with valid period_end_times and subpop_indices passes.""" + period_end_times = jnp.array([5, 10, 15, 20]) + subpop_indices = jnp.array([0, 1, 2, 0]) + obs = jnp.array([10.0, 20.0, 30.0, 15.0]) + daily_irregular_subpop_counts.validate_data( + n_total=30, + n_subpops=3, + obs=obs, + period_end_times=period_end_times, + subpop_indices=subpop_indices, + ) + + def test_weekly_irregular_valid_passes( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Weekly-irregular with Saturdays at offset 0 passes.""" + subpop_indices = jnp.array([0, 1, 2]) + obs = jnp.array([10.0, 20.0, 30.0]) + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + period_end_times=mmwr_saturday_indices_first_three, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_irregular_misaligned_raises(self, weekly_irregular_subpop_counts): + """Weekly-irregular with non-Saturday period_end_times raises.""" + period_end_times = jnp.array([5, 13, 20]) + subpop_indices = jnp.array([0, 1, 2]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + period_end_times=period_end_times, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_irregular_missing_first_day_dow_raises( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Weekly-irregular without first_day_dow raises.""" + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + period_end_times=mmwr_saturday_indices_first_three, + ) + + def test_irregular_obs_shape_mismatch_raises( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Irregular-schedule obs length must match period_end_times.""" + obs = jnp.array([10.0, 20.0]) + with pytest.raises(ValueError, match="must match"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + period_end_times=mmwr_saturday_indices_first_three, + first_day_dow=6, + ) + + def test_irregular_subpop_indices_shape_mismatch_raises( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Irregular-schedule subpop_indices length must match period_end_times.""" + subpop_indices = jnp.array([0, 1]) + with pytest.raises(ValueError, match="must match"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + period_end_times=mmwr_saturday_indices_first_three, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + +# =================================================================== +# SubpopulationCounts with aggregation: sample +# =================================================================== + + +class TestSubpopulationCountsAggregationSample: + """Sample-time behavior for the new aggregation paths.""" + + def test_daily_regular_shape(self, simple_delay_pmf, subpop_infections_30d): + """Daily-regular sample returns predicted of shape (n_total, n_subpops).""" + process = SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=subpop_infections_30d, subpop_indices=subpop_indices + ) + assert result.predicted.shape == (30, 3) + assert result.observed.shape == (30, 2) + + def test_weekly_regular_predicted_shape_and_values( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """Weekly-regular aggregates (n_total, n_subpops) to (n_periods, n_subpops).""" + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (4, 3) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + assert result.observed.shape == (4, 2) + + def test_weekly_regular_emits_predicted_daily_site( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """aggregation_period > 1 emits a 'predicted_daily' deterministic site.""" + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert "ed_predicted_daily" in trace + assert "ed_predicted" in trace + assert trace["ed_predicted_daily"]["value"].shape == (28, 3) + assert trace["ed_predicted"]["value"].shape == (4, 3) + + def test_daily_backward_compat_no_predicted_daily_site( + self, simple_delay_pmf, subpop_infections_28d + ): + """aggregation_period == 1 emits only 'predicted', not 'predicted_daily'.""" + process = SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + subpop_indices = jnp.array([0, 1, 2]) + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + process.sample( + infections=subpop_infections_28d, subpop_indices=subpop_indices + ) + assert "ed_predicted_daily" not in trace + assert "ed_predicted" in trace + assert trace["ed_predicted"]["value"].shape == (28, 3) + + def test_missing_subpop_indices_raises( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """subpop_indices is required in every sample call.""" + with pytest.raises(ValueError, match="subpop_indices is required"): + with numpyro.handlers.seed(rng_seed=42): + weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, first_day_dow=6 + ) + + def test_weekly_irregular_period_indexing( + self, weekly_irregular_subpop_counts, subpop_infections_28d + ): + """Weekly-irregular fancy-indexes the aggregated array at (period, subpop).""" + period_end_times = jnp.array([6, 20]) + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_irregular_subpop_counts.sample( + infections=subpop_infections_28d, + period_end_times=period_end_times, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (4, 3) + assert result.observed.shape == (2,) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + + def test_weekly_irregular_missing_period_end_times_raises( + self, weekly_irregular_subpop_counts, subpop_infections_28d + ): + """Irregular schedule requires period_end_times at sample time.""" + subpop_indices = jnp.array([0, 1, 2]) + with pytest.raises(ValueError, match="period_end_times is required"): + with numpyro.handlers.seed(rng_seed=42): + weekly_irregular_subpop_counts.sample( + infections=subpop_infections_28d, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_daily_irregular_fancy_indexing( + self, daily_irregular_subpop_counts, subpop_infections_30d + ): + """Daily-irregular indexes predicted at (period_end_times, subpop_indices).""" + period_end_times = jnp.array([5, 10, 20]) + subpop_indices = jnp.array([0, 1, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = daily_irregular_subpop_counts.sample( + infections=subpop_infections_30d, + period_end_times=period_end_times, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (30, 3) + assert result.observed.shape == (3,) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py index 3632f04a..16ca3db4 100644 --- a/test/test_observation_validation.py +++ b/test/test_observation_validation.py @@ -19,6 +19,7 @@ SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable +from pyrenew.time import daily_to_weekly # --------------------------------------------------------------------------- # Helpers – minimal concrete subclass of MeasurementObservation for testing @@ -95,6 +96,7 @@ def subpop_proc(): ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.4, 0.4, 0.2])), noise=PoissonNoise(), + reporting_schedule="irregular", ) @@ -188,6 +190,14 @@ def test_non_contiguous_valid_indices(self, counts_proc): param_name="test_idx", ) + def test_empty_array_passes(self, counts_proc): + """Empty index array has no values to bounds-check; should not raise.""" + counts_proc._validate_index_array( + jnp.array([], dtype=jnp.int32), + upper_bound=10, + param_name="test_idx", + ) + # =================================================================== # _validate_times @@ -375,29 +385,33 @@ def test_all_none_passes(self, subpop_proc): subpop_proc.validate_data(n_total=30, n_subpops=3) def test_valid_data_passes(self, subpop_proc): - """validate_data with valid times, subpop_indices, obs should not raise.""" - times = jnp.array([5, 10, 15, 20]) + """validate_data with valid period_end_times, subpop_indices, obs should not raise.""" + period_end_times = jnp.array([5, 10, 15, 20]) subpop_indices = jnp.array([0, 1, 2, 0]) obs = jnp.array([10.0, 20.0, 30.0, 15.0]) subpop_proc.validate_data( n_total=30, n_subpops=3, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=obs, ) def test_invalid_times_raises(self, subpop_proc): - """validate_data with out-of-bounds times should raise.""" - times = jnp.array([5, 30]) # 30 == n_total, out of bounds + """validate_data with out-of-bounds period_end_times should raise.""" + period_end_times = jnp.array([5, 30]) # 30 == n_total, out of bounds with pytest.raises(ValueError, match="upper bound"): - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + subpop_proc.validate_data( + n_total=30, n_subpops=3, period_end_times=period_end_times + ) def test_negative_times_raises(self, subpop_proc): - """validate_data with negative times should raise.""" - times = jnp.array([-1, 5]) + """validate_data with negative period_end_times should raise.""" + period_end_times = jnp.array([-1, 5]) with pytest.raises(ValueError, match="cannot be negative"): - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + subpop_proc.validate_data( + n_total=30, n_subpops=3, period_end_times=period_end_times + ) def test_invalid_subpop_indices_raises(self, subpop_proc): """validate_data with out-of-bounds subpop_indices should raise.""" @@ -416,22 +430,28 @@ def test_negative_subpop_indices_raises(self, subpop_proc): ) def test_mismatched_obs_times_raises(self, subpop_proc): - """validate_data with obs/times shape mismatch should raise.""" - times = jnp.array([5, 10, 15]) + """validate_data with obs/period_end_times shape mismatch should raise.""" + period_end_times = jnp.array([5, 10, 15]) obs = jnp.array([1.0, 2.0]) # length 2 != length 3 with pytest.raises(ValueError, match="must match times shape"): - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times, obs=obs) + subpop_proc.validate_data( + n_total=30, + n_subpops=3, + period_end_times=period_end_times, + obs=obs, + ) def test_obs_without_times_skips_shape_check(self, subpop_proc): - """validate_data with obs but no times should not check shapes.""" + """validate_data with obs but no period_end_times should not check shapes.""" obs = jnp.array([1.0, 2.0]) - # times is None, so shape check is skipped subpop_proc.validate_data(n_total=30, n_subpops=3, obs=obs) def test_times_without_obs_skips_shape_check(self, subpop_proc): - """validate_data with times but no obs should validate times only.""" - times = jnp.array([5, 10, 15]) - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + """validate_data with period_end_times but no obs should validate indices only.""" + period_end_times = jnp.array([5, 10, 15]) + subpop_proc.validate_data( + n_total=30, n_subpops=3, period_end_times=period_end_times + ) def test_non_contiguous_subpop_indices_valid(self, subpop_proc): """validate_data with non-contiguous but valid subpop_indices passes.""" @@ -556,3 +576,202 @@ def test_non_contiguous_subpop_indices_valid(self, measurements_proc): def test_extra_kwargs_ignored(self, measurements_proc): """validate_data should ignore extra keyword arguments.""" measurements_proc.validate_data(n_total=30, n_subpops=3, foo="bar") + + +# =================================================================== +# _validate_aggregation_params +# =================================================================== + + +class TestValidateAggregationParams: + """Tests for BaseObservationProcess._validate_aggregation_params.""" + + def test_p1_no_anchor_passes(self, counts_proc): + """P=1 with period_end_dow=None should not raise.""" + counts_proc._validate_aggregation_params(1, None) + + def test_p1_anchor_ignored(self, counts_proc): + """P=1 with any period_end_dow value should not raise.""" + counts_proc._validate_aggregation_params(1, 5) + counts_proc._validate_aggregation_params(1, 99) + + def test_p7_valid_anchor_passes(self, counts_proc): + """P=7 with period_end_dow in {0, ..., 6} should not raise.""" + for dow in range(7): + counts_proc._validate_aggregation_params(7, dow) + + def test_p7_missing_anchor_raises(self, counts_proc): + """P=7 with period_end_dow=None should raise.""" + with pytest.raises(ValueError, match="period_end_dow is required"): + counts_proc._validate_aggregation_params(7, None) + + def test_p7_negative_anchor_raises(self, counts_proc): + """P=7 with period_end_dow<0 should raise.""" + with pytest.raises(ValueError, match="integer in"): + counts_proc._validate_aggregation_params(7, -1) + + def test_p7_too_large_anchor_raises(self, counts_proc): + """P=7 with period_end_dow>6 should raise.""" + with pytest.raises(ValueError, match="integer in"): + counts_proc._validate_aggregation_params(7, 7) + + def test_zero_period_raises(self, counts_proc): + """aggregation_period=0 should raise.""" + with pytest.raises(ValueError, match="positive integer"): + counts_proc._validate_aggregation_params(0, None) + + def test_negative_period_raises(self, counts_proc): + """Negative aggregation_period should raise.""" + with pytest.raises(ValueError, match="positive integer"): + counts_proc._validate_aggregation_params(-3, None) + + def test_unsupported_period_raises(self, counts_proc): + """aggregation_period not in {1, 7} should raise.""" + with pytest.raises(ValueError, match=r"one of \{1, 7\}"): + counts_proc._validate_aggregation_params(14, 5) + + def test_float_period_raises(self, counts_proc): + """Non-integer aggregation_period should raise.""" + with pytest.raises(ValueError, match="positive integer"): + counts_proc._validate_aggregation_params(7.0, 5) + + +# =================================================================== +# _compute_period_offset +# =================================================================== + + +class TestComputePeriodOffset: + """Tests for BaseObservationProcess._compute_period_offset.""" + + def test_p1_returns_zero(self, counts_proc): + """P=1 always returns 0 regardless of dow arguments.""" + assert counts_proc._compute_period_offset(None, 1, None) == 0 + assert counts_proc._compute_period_offset(0, 1, 5) == 0 + assert counts_proc._compute_period_offset(6, 1, None) == 0 + + def test_p7_mmwr_aligned_start(self, counts_proc): + """Daily axis starting Sunday with Saturday end => offset 0.""" + assert counts_proc._compute_period_offset(6, 7, 5) == 0 + + def test_p7_monday_start_saturday_end(self, counts_proc): + """Daily axis starting Monday with Saturday end => offset 6.""" + assert counts_proc._compute_period_offset(0, 7, 5) == 6 + + def test_p7_saturday_start_saturday_end(self, counts_proc): + """Daily axis starting Saturday with Saturday end => offset 1.""" + assert counts_proc._compute_period_offset(5, 7, 5) == 1 + + def test_p7_iso_week_alignment(self, counts_proc): + """Daily axis starting Thursday with Sunday end (ISO) => offset 4.""" + assert counts_proc._compute_period_offset(3, 7, 6) == 4 + + def test_p7_offset_always_in_range(self, counts_proc): + """P=7 offset is always in [0, 7) for any valid dow combination.""" + for first in range(7): + for end in range(7): + offset = counts_proc._compute_period_offset(first, 7, end) + assert 0 <= offset < 7 + + def test_p7_missing_first_day_dow_raises(self, counts_proc): + """P=7 with first_day_dow=None should raise.""" + with pytest.raises(ValueError, match="both required"): + counts_proc._compute_period_offset(None, 7, 5) + + def test_p7_missing_period_end_dow_raises(self, counts_proc): + """P=7 with period_end_dow=None should raise.""" + with pytest.raises(ValueError, match="both required"): + counts_proc._compute_period_offset(0, 7, None) + + def test_unsupported_period_raises(self, counts_proc): + """P not in {1, 7} should raise.""" + with pytest.raises(ValueError, match=r"one of \{1, 7\}"): + counts_proc._compute_period_offset(0, 14, 5) + + def test_offset_agrees_with_daily_to_weekly(self, counts_proc): + """ + Offset from _compute_period_offset selects the same leading + days that daily_to_weekly trims internally for the first + complete period, for every (first_day_dow, period_end_dow). + """ + daily = jnp.arange(21.0) + for first in range(7): + for end in range(7): + offset = counts_proc._compute_period_offset(first, 7, end) + weekly = daily_to_weekly( + daily, + input_data_first_dow=first, + week_start_dow=(end + 1) % 7, + ) + expected_first_week = float(jnp.sum(daily[offset : offset + 7])) + assert float(weekly[0]) == expected_first_week + + +# =================================================================== +# _validate_period_end_times +# =================================================================== + + +class TestValidatePeriodEndTimes: + """Tests for BaseObservationProcess._validate_period_end_times.""" + + def test_p7_aligned_times_pass(self, counts_proc): + """P=7 with Saturdays at offset 0 should not raise.""" + times = jnp.array([6, 13, 20]) + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_p7_aligned_nonzero_offset_pass(self, counts_proc): + """P=7 with nonzero offset shifts the boundary days.""" + times = jnp.array([12, 19, 26]) + counts_proc._validate_period_end_times( + times, n_total=30, offset=6, aggregation_period=7 + ) + + def test_p1_any_in_bounds_passes(self, counts_proc): + """P=1: alignment is trivial; any in-bounds index passes.""" + times = jnp.array([0, 3, 7, 19]) + counts_proc._validate_period_end_times( + times, n_total=20, offset=0, aggregation_period=1 + ) + + def test_misaligned_time_raises(self, counts_proc): + """P=7 with a non-boundary time should raise.""" + times = jnp.array([5]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_partial_misalignment_raises(self, counts_proc): + """Any single misaligned entry should raise.""" + times = jnp.array([6, 12, 20]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_negative_time_raises(self, counts_proc): + """Negative period_end_times should raise via bounds check.""" + times = jnp.array([-1, 6]) + with pytest.raises(ValueError, match="cannot be negative"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_time_at_n_total_raises(self, counts_proc): + """period_end_times == n_total should raise via bounds check.""" + times = jnp.array([21]) + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_error_reports_offset_and_period(self, counts_proc): + """Alignment error message should include offset and aggregation_period.""" + times = jnp.array([5]) + with pytest.raises(ValueError, match=r"offset=0.*aggregation_period=7"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index ff60a309..04f5da27 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -93,6 +93,7 @@ def validation_builder(): ascertainment_rate_rv=DeterministicVariable("ihr_subpop", 0.01), delay_distribution_rv=delay, noise=NegativeBinomialNoise(DeterministicVariable("conc_subpop", 10.0)), + reporting_schedule="irregular", ) ) @@ -284,7 +285,7 @@ def test_validate_data_accepts_valid_data(self, validation_builder): }, hospital_subpop={ "obs": jnp.array([10, 20]), - "times": jnp.array([5, 10]), + "period_end_times": jnp.array([5, 10]), }, ) @@ -299,7 +300,7 @@ def test_validate_data_rejects_out_of_bounds_times(self, validation_builder): subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "obs": jnp.array([10]), - "times": jnp.array([n_total + 10]), + "period_end_times": jnp.array([n_total + 10]), }, ) @@ -313,7 +314,7 @@ def test_validate_data_rejects_negative_times(self, validation_builder): subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "obs": jnp.array([10]), - "times": jnp.array([-1]), + "period_end_times": jnp.array([-1]), }, ) @@ -327,7 +328,7 @@ def test_validate_data_rejects_unknown_observation(self, validation_builder): subpop_fractions=SUBPOP_FRACTIONS, unknown_obs={ "obs": jnp.array([10]), - "times": jnp.array([5]), + "period_end_times": jnp.array([5]), }, ) @@ -343,7 +344,7 @@ def test_validate_data_rejects_mismatched_obs_times_length( subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "obs": jnp.array([10, 20, 30]), # 3 elements - "times": jnp.array([5, 10]), # 2 elements + "period_end_times": jnp.array([5, 10]), # 2 elements }, ) @@ -363,7 +364,7 @@ def test_validate_data_rejects_negative_subpop_indices(self, validation_builder) subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "subpop_indices": jnp.array([-1, 0, 1]), - "times": jnp.array([5, 6, 7]), + "period_end_times": jnp.array([5, 6, 7]), }, ) @@ -380,7 +381,7 @@ def test_validate_data_rejects_out_of_bounds_subpop_indices( subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "subpop_indices": jnp.array([0, 1, 5]), # 5 >= 3 - "times": jnp.array([5, 6, 7]), + "period_end_times": jnp.array([5, 6, 7]), }, ) From 40658426151161076baa3b82c1cde345babf36a4 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 21 Apr 2026 18:14:46 -0400 Subject: [PATCH 02/17] add stepwise temporal process and unit tests --- pyrenew/latent/__init__.py | 2 + pyrenew/latent/temporal_processes.py | 115 +++++++++++++++++++++++++++ test/test_temporal_processes.py | 95 +++++++++++++++++++++- 3 files changed, 211 insertions(+), 1 deletion(-) diff --git a/pyrenew/latent/__init__.py b/pyrenew/latent/__init__.py index 45f6f06b..aecd5db1 100644 --- a/pyrenew/latent/__init__.py +++ b/pyrenew/latent/__init__.py @@ -32,6 +32,7 @@ AR1, DifferencedAR1, RandomWalk, + StepwiseTemporalProcess, TemporalProcess, ) @@ -61,4 +62,5 @@ "AR1", "DifferencedAR1", "RandomWalk", + "StepwiseTemporalProcess", ] diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 4020564f..5fef2101 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -39,6 +39,10 @@ while stabilizing the growth rate. Wraps [pyrenew.process.DifferencedProcess][]. - ``RandomWalk``: No mean reversion. Rt can drift without bound. Wraps [pyrenew.process.RandomWalk][]. +- ``StepwiseTemporalProcess``: Wrapper that parameterizes any inner + ``TemporalProcess`` at a coarser cadence and broadcasts to the per-timepoint + scale by repetition. Use to match R(t) parametrization to the coarsest + observation cadence. All implementations satisfy the ``TemporalProcess`` protocol and can be used interchangeably in hierarchical infection models. @@ -66,8 +70,19 @@ class TemporalProcess(Protocol): Used for jurisdiction-level Rt dynamics, subpopulation deviations, or allocation trajectories. All processes return 2D arrays of shape (n_timepoints, n_processes) for consistent handling. + + Attributes + ---------- + step_size : int + Number of consecutive timepoints that share the same sampled value. + Defaults to ``1`` for the standard processes (one independent sample + per timepoint). Wrapper processes like ``StepwiseTemporalProcess`` + expose a larger value so that model builders can enforce coherence + between R(t) parametrization cadence and observation aggregation. """ + step_size: int + def sample( self, n_timepoints: int, @@ -120,6 +135,8 @@ class AR1(TemporalProcess): more volatile trajectories; smaller values produce smoother ones. """ + step_size: int = 1 + def __init__(self, autoreg: float, innovation_sd: float = 1.0) -> None: """ Initialize AR(1) process. @@ -220,6 +237,8 @@ class DifferencedAR1(TemporalProcess): more erratic growth rates; smaller values produce smoother trends. """ + step_size: int = 1 + def __init__(self, autoreg: float, innovation_sd: float = 1.0) -> None: """ Initialize differenced AR(1) process. @@ -331,6 +350,8 @@ class RandomWalk(TemporalProcess): (``{name_prefix}_step``) via ``numpyro.handlers.reparam``. """ + step_size: int = 1 + def __init__(self, innovation_sd: float = 1.0) -> None: """ Initialize random walk process. @@ -399,3 +420,97 @@ def sample( init_vals=initial_value[jnp.newaxis, :], n=n_timepoints, ) + + +class StepwiseTemporalProcess(TemporalProcess): + """ + Parameterize an inner temporal process at a coarser cadence and + broadcast to the per-timepoint scale by repetition. + + Each ``step_size`` consecutive output timepoints share a single + sampled value from the inner process. Use to match R(t) + parametrization cadence to the coarsest observation cadence + (e.g., ``step_size=7`` with weekly-aggregated observations). + + Parameters + ---------- + inner + Inner ``TemporalProcess`` that generates the coarse-scale + trajectory. Must satisfy the ``TemporalProcess`` protocol. + step_size + Number of per-timepoint units that share each inner sample. + Must be a positive integer. + + Raises + ------ + ValueError + If ``step_size`` is not a positive integer. + """ + + def __init__(self, inner: TemporalProcess, step_size: int) -> None: + """ + Initialize stepwise temporal process. + + Parameters + ---------- + inner + Inner ``TemporalProcess`` that generates the coarse trajectory. + step_size + Number of per-timepoint units that share each inner sample. + + Raises + ------ + ValueError + If ``step_size`` is not a positive integer. + """ + if not isinstance(step_size, int) or step_size < 1: + raise ValueError(f"step_size must be a positive integer, got {step_size!r}") + self.inner = inner + self.step_size = step_size + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"StepwiseTemporalProcess(inner={self.inner!r}, step_size={self.step_size})" + ) + + def sample( + self, + n_timepoints: int, + initial_value: float | ArrayLike | None = None, + n_processes: int = 1, + name_prefix: str = "stepwise", + ) -> ArrayLike: + """ + Sample coarse trajectory from inner process and broadcast. + + Computes ``n_steps = ceil(n_timepoints / step_size)``, samples + the inner process at that cadence, then repeats each coarse + value ``step_size`` times along the time axis and trims to + ``n_timepoints``. + + Parameters + ---------- + n_timepoints + Number of per-timepoint outputs to produce. + initial_value + Initial value(s) for the inner process. Defaults to 0.0. + n_processes + Number of parallel processes. + name_prefix + Prefix for numpyro sample sites; forwarded to the inner process. + + Returns + ------- + ArrayLike + Trajectories of shape ``(n_timepoints, n_processes)``, constant + within each block of ``step_size`` consecutive rows. + """ + n_steps = (n_timepoints + self.step_size - 1) // self.step_size + coarse = self.inner.sample( + n_timepoints=n_steps, + initial_value=initial_value, + n_processes=n_processes, + name_prefix=name_prefix, + ) + return jnp.repeat(coarse, repeats=self.step_size, axis=0)[:n_timepoints] diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index 7248518e..48ebd2df 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -6,7 +6,13 @@ import numpyro import pytest -from pyrenew.latent import AR1, DifferencedAR1, RandomWalk +from pyrenew.latent import AR1, DifferencedAR1, RandomWalk, StepwiseTemporalProcess + +INNER_PROCESS_PARAMS = [ + (AR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (DifferencedAR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (RandomWalk, {"innovation_sd": 0.05}), +] class TestTemporalProcessVectorizedSampling: @@ -236,5 +242,92 @@ def test_differenced_ar1_trend_persistence(self): assert fraction_positive > 0.5 +class TestTemporalProcessStepSizeDefault: + """Standard temporal processes expose step_size=1 as a class attribute.""" + + @pytest.mark.parametrize( + "process_cls", + [AR1, DifferencedAR1, RandomWalk], + ) + def test_class_attribute_step_size_is_one(self, process_cls): + """Each concrete class has step_size=1 as a class attribute.""" + assert process_cls.step_size == 1 + + +class TestStepwiseTemporalProcessConstruction: + """Construction-time validation for StepwiseTemporalProcess.""" + + def test_step_size_attribute(self): + """step_size is exposed on the instance for builder inspection.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ) + assert wrapper.step_size == 7 + + def test_zero_step_size_raises(self): + """step_size=0 raises.""" + with pytest.raises(ValueError, match="positive integer"): + StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=0) + + def test_negative_step_size_raises(self): + """Negative step_size raises.""" + with pytest.raises(ValueError, match="positive integer"): + StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=-1) + + def test_float_step_size_raises(self): + """Non-integer step_size raises.""" + with pytest.raises(ValueError, match="positive integer"): + StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=7.0) + + +class TestStepwiseTemporalProcessSample: + """Sample-time behavior of StepwiseTemporalProcess.""" + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_output_shape_divisible(self, inner_cls, inner_kwargs): + """n_timepoints divisible by step_size yields (n_timepoints, n_processes).""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=28, n_processes=3) + assert result.shape == (28, 3) + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_output_shape_non_divisible(self, inner_cls, inner_kwargs): + """n_timepoints not divisible by step_size still yields n_timepoints rows.""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=30, n_processes=2) + assert result.shape == (30, 2) + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_broadcast_within_block(self, inner_cls, inner_kwargs): + """Each block of step_size consecutive timepoints is constant.""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=28, n_processes=1) + for start in range(0, 28, 7): + block = result[start : start + 7] + assert jnp.allclose(block, block[0]) + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_broadcast_with_partial_final_block(self, inner_cls, inner_kwargs): + """A partial final block is still constant, just shorter than step_size.""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=30, n_processes=1) + # The 4th block starts at row 28 and runs through row 29 (2 rows) + final_block = result[28:30] + assert jnp.allclose(final_block, final_block[0]) + + def test_step_size_one_passthrough_shape(self): + """step_size=1 yields (n_timepoints, n_processes), same as inner directly.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=1 + ) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=20, n_processes=2) + assert result.shape == (20, 2) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From f06d98d11e02877c417594e9b59581579bcfa74f Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 21 Apr 2026 19:42:14 -0400 Subject: [PATCH 03/17] updated builder, added integration test for agg hosp data, all unit tests passing --- _typos.toml | 1 + ...agen_he_CA_120.py => datagen_he_CA_126.py} | 67 ++++--- .../synthetic_CA_120/daily_ed_visits.csv | 121 ------------ .../daily_hospital_admissions.csv | 121 ------------ .../synthetic_CA_120/daily_infections.csv | 121 ------------ .../weekly_hospital_admissions.csv | 17 -- .../synthetic_CA_126/daily_ed_visits.csv | 127 ++++++++++++ .../daily_hospital_admissions.csv | 127 ++++++++++++ .../synthetic_CA_126/daily_infections.csv | 127 ++++++++++++ .../true_parameters.json | 13 +- .../weekly_hospital_admissions.csv | 19 ++ pyrenew/datasets/synthetic_data.py | 10 +- pyrenew/model/multisignal_model.py | 12 +- pyrenew/model/pyrenew_builder.py | 83 +++++++- test/conftest.py | 14 ++ test/integration/conftest.py | 82 ++++++++ .../test_population_infections_he.py | 10 +- test/test_ar_process.py | 19 +- ...he_CA_120.py => test_datagen_he_CA_126.py} | 67 ++++--- test/test_datasets_synthetic.py | 24 +-- test/test_pyrenew_builder.py | 187 +++++++++++++++++- 21 files changed, 901 insertions(+), 468 deletions(-) rename pyrenew/datasets/{datagen_he_CA_120.py => datagen_he_CA_126.py} (84%) delete mode 100644 pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv delete mode 100644 pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv delete mode 100644 pyrenew/datasets/synthetic_CA_120/daily_infections.csv delete mode 100644 pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv create mode 100644 pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv create mode 100644 pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv create mode 100644 pyrenew/datasets/synthetic_CA_126/daily_infections.csv rename pyrenew/datasets/{synthetic_CA_120 => synthetic_CA_126}/true_parameters.json (86%) create mode 100644 pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv rename test/{test_datagen_he_CA_120.py => test_datagen_he_CA_126.py} (74%) diff --git a/_typos.toml b/_typos.toml index 47865732..8cf4ff9c 100644 --- a/_typos.toml +++ b/_typos.toml @@ -3,3 +3,4 @@ # words that should not be corrected arange = "arange" lod = "lod" +dows = "dows" diff --git a/pyrenew/datasets/datagen_he_CA_120.py b/pyrenew/datasets/datagen_he_CA_126.py similarity index 84% rename from pyrenew/datasets/datagen_he_CA_120.py rename to pyrenew/datasets/datagen_he_CA_126.py index 3ccfeaf9..476c8036 100644 --- a/pyrenew/datasets/datagen_he_CA_120.py +++ b/pyrenew/datasets/datagen_he_CA_126.py @@ -1,13 +1,18 @@ # numpydoc ignore=ES01,SA01,EX01 """ -Generate 120-day synthetic CA hospital + ED visit data from known R(t). +Generate 126-day synthetic CA hospital + ED visit data from known R(t). This script defines R(t) directly, runs a renewal equation forward, and convolves with signal-specific delay PMFs to produce two observation streams. All true parameters are saved alongside the synthetic observations so tutorials can demonstrate posterior recovery. -Outputs (in pyrenew/datasets/synthetic_CA_120/): +The start date (2023-11-05, Sunday) and length (126 days = 18 weeks) +are chosen so that the daily time axis aligns exactly with 18 complete +MMWR epiweeks, with no days trimmed at either end during weekly +aggregation. + +Outputs (in pyrenew/datasets/synthetic_CA_126/): - true_parameters.json - daily_infections.csv - daily_ed_visits.csv @@ -16,7 +21,7 @@ Run from repo root:: - python -m pyrenew.datasets.datagen_he_CA_120 + python -m pyrenew.datasets.datagen_he_CA_126 """ from __future__ import annotations @@ -33,7 +38,7 @@ from pyrenew.time import daily_to_mmwr_epiweekly, get_sequential_day_of_week_indices REPO_ROOT = Path(__file__).resolve().parents[2] -OUTPUT_DIR = REPO_ROOT / "pyrenew" / "datasets" / "synthetic_CA_120" +OUTPUT_DIR = REPO_ROOT / "pyrenew" / "datasets" / "synthetic_CA_126" RNG_SEED = 20240101 POPULATION = 39_512_223 @@ -123,10 +128,11 @@ IEDR = 0.0075 I0_PER_CAPITA = 5e-4 NEGBINOM_CONCENTRATION_HOSP = 350.0 +NEGBINOM_CONCENTRATION_HOSP_WEEKLY = 100.0 NEGBINOM_CONCENTRATION_ED = 50.0 DOW_EFFECTS = np.array([1.15, 1.12, 1.08, 1.05, 0.98, 0.82, 0.80]) -START_DATE = date(2023, 11, 6) +START_DATE = date(2023, 11, 5) N_INIT = max(50, len(HOSP_DELAY_PMF), len(ED_DELAY_PMF)) @@ -134,8 +140,8 @@ def build_true_rt() -> np.ndarray: """ Build a piecewise-linear true R(t) trajectory. - Phases: decline from 1.2 to 0.8 (60 d), rise from 0.8 to 1.1 (40 d), - decline from 1.1 to 0.85 (20 d). + Phases: decline from 1.2 to 0.8 (63 d), rise from 0.8 to 1.1 (42 d), + decline from 1.1 to 0.85 (21 d). Total 126 days = 18 weeks. Returns ------- @@ -143,9 +149,9 @@ def build_true_rt() -> np.ndarray: R(t) trajectory. """ segments = [ - (60, 1.2, 0.8), - (40, 0.8, 1.1), - (20, 1.1, 0.85), + (63, 1.2, 0.8), + (42, 0.8, 1.1), + (21, 1.1, 0.85), ] rt = np.concatenate( [ @@ -255,21 +261,21 @@ def sample_negbinom( return rng.negative_binomial(n=concentration, p=p) -def aggregate_to_epiweeks( - daily_values: np.ndarray, +def build_weekly_hosp_frame( + weekly_values: np.ndarray, start_date: date, ) -> pl.DataFrame: """ - Aggregate daily counts to MMWR epiweek totals (Sun-Sat, labeled by Saturday). - - Only complete 7-day weeks are kept. + Build a weekly hospital admissions frame with MMWR week-ending dates. Parameters ---------- - daily_values : np.ndarray - Daily count time series. + weekly_values : np.ndarray + One value per MMWR epiweek (aggregated by the caller). start_date : date - Date of the first element. + Date of day 0 of the daily time series the weekly values were + aggregated from. Used to compute the date of the first + Saturday (week-ending day). Returns ------- @@ -277,9 +283,6 @@ def aggregate_to_epiweeks( Columns: week_end, weekly_hosp_admits. """ first_dow = start_date.weekday() - weekly_values = daily_to_mmwr_epiweekly( - np.asarray(daily_values), input_data_first_dow=first_dow - ) days_to_first_sunday = (6 - first_dow) % 7 first_week_end = start_date + timedelta(days=days_to_first_sunday + 6) n_weeks = len(weekly_values) @@ -319,7 +322,18 @@ def generate() -> None: expected_hosp_daily, NEGBINOM_CONCENTRATION_HOSP, rng ) - weekly_hosp = aggregate_to_epiweeks(hosp_daily_obs, START_DATE) + # Weekly hospital admissions: aggregate daily *expected* values to MMWR + # epiweeks, then apply one NegBin at the weekly scale. This matches the + # generative assumption of PopulationCounts(aggregation_period=7) with a + # single NegativeBinomialNoise at the reporting cadence. + expected_hosp_weekly = np.asarray( + daily_to_mmwr_epiweekly(expected_hosp_daily, input_data_first_dow=first_dow) + ) + expected_hosp_weekly = np.maximum(expected_hosp_weekly, 1.0) + hosp_weekly_obs = sample_negbinom( + expected_hosp_weekly, NEGBINOM_CONCENTRATION_HOSP_WEEKLY, rng + ) + weekly_hosp = build_weekly_hosp_frame(hosp_weekly_obs, START_DATE) hosp_daily_df = pl.DataFrame( { @@ -380,14 +394,15 @@ def generate() -> None: "generation_interval_pmf": GEN_INT_PMF.tolist(), "i0_per_capita": I0_PER_CAPITA, "rt_trajectory": { - "phase_1": {"days": 60, "start": 1.2, "end": 0.8}, - "phase_2": {"days": 40, "start": 0.8, "end": 1.1}, - "phase_3": {"days": 20, "start": 1.1, "end": 0.85}, + "phase_1": {"days": 63, "start": 1.2, "end": 0.8}, + "phase_2": {"days": 42, "start": 0.8, "end": 1.1}, + "phase_3": {"days": 21, "start": 1.1, "end": 0.85}, }, "hospitalizations": { "ihr": IHR, "delay_pmf_source": "infection_admission_interval.tsv", - "negbinom_concentration": NEGBINOM_CONCENTRATION_HOSP, + "negbinom_concentration_daily": NEGBINOM_CONCENTRATION_HOSP, + "negbinom_concentration_weekly": NEGBINOM_CONCENTRATION_HOSP_WEEKLY, "temporal_resolutions": ["daily", "weekly_epiweek"], }, "ed_visits": { diff --git a/pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv b/pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv deleted file mode 100644 index 8bfdc09a..00000000 --- a/pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv +++ /dev/null @@ -1,121 +0,0 @@ -date,geo_value,disease,ed_visits -2023-11-06,CA,COVID-19,144 -2023-11-07,CA,COVID-19,120 -2023-11-08,CA,COVID-19,151 -2023-11-09,CA,COVID-19,154 -2023-11-10,CA,COVID-19,134 -2023-11-11,CA,COVID-19,133 -2023-11-12,CA,COVID-19,158 -2023-11-13,CA,COVID-19,189 -2023-11-14,CA,COVID-19,208 -2023-11-15,CA,COVID-19,307 -2023-11-16,CA,COVID-19,259 -2023-11-17,CA,COVID-19,350 -2023-11-18,CA,COVID-19,369 -2023-11-19,CA,COVID-19,381 -2023-11-20,CA,COVID-19,401 -2023-11-21,CA,COVID-19,629 -2023-11-22,CA,COVID-19,449 -2023-11-23,CA,COVID-19,551 -2023-11-24,CA,COVID-19,422 -2023-11-25,CA,COVID-19,409 -2023-11-26,CA,COVID-19,349 -2023-11-27,CA,COVID-19,593 -2023-11-28,CA,COVID-19,740 -2023-11-29,CA,COVID-19,700 -2023-11-30,CA,COVID-19,767 -2023-12-01,CA,COVID-19,703 -2023-12-02,CA,COVID-19,688 -2023-12-03,CA,COVID-19,604 -2023-12-04,CA,COVID-19,1018 -2023-12-05,CA,COVID-19,814 -2023-12-06,CA,COVID-19,851 -2023-12-07,CA,COVID-19,890 -2023-12-08,CA,COVID-19,644 -2023-12-09,CA,COVID-19,711 -2023-12-10,CA,COVID-19,614 -2023-12-11,CA,COVID-19,1092 -2023-12-12,CA,COVID-19,815 -2023-12-13,CA,COVID-19,539 -2023-12-14,CA,COVID-19,729 -2023-12-15,CA,COVID-19,711 -2023-12-16,CA,COVID-19,638 -2023-12-17,CA,COVID-19,502 -2023-12-18,CA,COVID-19,894 -2023-12-19,CA,COVID-19,716 -2023-12-20,CA,COVID-19,599 -2023-12-21,CA,COVID-19,536 -2023-12-22,CA,COVID-19,605 -2023-12-23,CA,COVID-19,354 -2023-12-24,CA,COVID-19,620 -2023-12-25,CA,COVID-19,708 -2023-12-26,CA,COVID-19,485 -2023-12-27,CA,COVID-19,528 -2023-12-28,CA,COVID-19,463 -2023-12-29,CA,COVID-19,322 -2023-12-30,CA,COVID-19,166 -2023-12-31,CA,COVID-19,178 -2024-01-01,CA,COVID-19,296 -2024-01-02,CA,COVID-19,299 -2024-01-03,CA,COVID-19,219 -2024-01-04,CA,COVID-19,188 -2024-01-05,CA,COVID-19,147 -2024-01-06,CA,COVID-19,68 -2024-01-07,CA,COVID-19,98 -2024-01-08,CA,COVID-19,113 -2024-01-09,CA,COVID-19,117 -2024-01-10,CA,COVID-19,74 -2024-01-11,CA,COVID-19,69 -2024-01-12,CA,COVID-19,84 -2024-01-13,CA,COVID-19,58 -2024-01-14,CA,COVID-19,26 -2024-01-15,CA,COVID-19,42 -2024-01-16,CA,COVID-19,53 -2024-01-17,CA,COVID-19,62 -2024-01-18,CA,COVID-19,42 -2024-01-19,CA,COVID-19,47 -2024-01-20,CA,COVID-19,32 -2024-01-21,CA,COVID-19,29 -2024-01-22,CA,COVID-19,19 -2024-01-23,CA,COVID-19,21 -2024-01-24,CA,COVID-19,28 -2024-01-25,CA,COVID-19,20 -2024-01-26,CA,COVID-19,9 -2024-01-27,CA,COVID-19,16 -2024-01-28,CA,COVID-19,15 -2024-01-29,CA,COVID-19,34 -2024-01-30,CA,COVID-19,30 -2024-01-31,CA,COVID-19,21 -2024-02-01,CA,COVID-19,19 -2024-02-02,CA,COVID-19,18 -2024-02-03,CA,COVID-19,24 -2024-02-04,CA,COVID-19,10 -2024-02-05,CA,COVID-19,25 -2024-02-06,CA,COVID-19,16 -2024-02-07,CA,COVID-19,27 -2024-02-08,CA,COVID-19,27 -2024-02-09,CA,COVID-19,29 -2024-02-10,CA,COVID-19,21 -2024-02-11,CA,COVID-19,17 -2024-02-12,CA,COVID-19,26 -2024-02-13,CA,COVID-19,25 -2024-02-14,CA,COVID-19,27 -2024-02-15,CA,COVID-19,24 -2024-02-16,CA,COVID-19,20 -2024-02-17,CA,COVID-19,21 -2024-02-18,CA,COVID-19,38 -2024-02-19,CA,COVID-19,41 -2024-02-20,CA,COVID-19,38 -2024-02-21,CA,COVID-19,42 -2024-02-22,CA,COVID-19,41 -2024-02-23,CA,COVID-19,32 -2024-02-24,CA,COVID-19,32 -2024-02-25,CA,COVID-19,27 -2024-02-26,CA,COVID-19,54 -2024-02-27,CA,COVID-19,45 -2024-02-28,CA,COVID-19,41 -2024-02-29,CA,COVID-19,26 -2024-03-01,CA,COVID-19,23 -2024-03-02,CA,COVID-19,19 -2024-03-03,CA,COVID-19,18 -2024-03-04,CA,COVID-19,24 diff --git a/pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv deleted file mode 100644 index fd32f156..00000000 --- a/pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv +++ /dev/null @@ -1,121 +0,0 @@ -date,geo_value,daily_hosp_admits,pop -2023-11-06,CA,28,39512223 -2023-11-07,CA,31,39512223 -2023-11-08,CA,43,39512223 -2023-11-09,CA,30,39512223 -2023-11-10,CA,57,39512223 -2023-11-11,CA,47,39512223 -2023-11-12,CA,66,39512223 -2023-11-13,CA,94,39512223 -2023-11-14,CA,66,39512223 -2023-11-15,CA,81,39512223 -2023-11-16,CA,93,39512223 -2023-11-17,CA,121,39512223 -2023-11-18,CA,146,39512223 -2023-11-19,CA,118,39512223 -2023-11-20,CA,175,39512223 -2023-11-21,CA,208,39512223 -2023-11-22,CA,197,39512223 -2023-11-23,CA,223,39512223 -2023-11-24,CA,226,39512223 -2023-11-25,CA,257,39512223 -2023-11-26,CA,276,39512223 -2023-11-27,CA,276,39512223 -2023-11-28,CA,300,39512223 -2023-11-29,CA,321,39512223 -2023-11-30,CA,328,39512223 -2023-12-01,CA,351,39512223 -2023-12-02,CA,354,39512223 -2023-12-03,CA,390,39512223 -2023-12-04,CA,441,39512223 -2023-12-05,CA,415,39512223 -2023-12-06,CA,381,39512223 -2023-12-07,CA,413,39512223 -2023-12-08,CA,425,39512223 -2023-12-09,CA,418,39512223 -2023-12-10,CA,478,39512223 -2023-12-11,CA,491,39512223 -2023-12-12,CA,514,39512223 -2023-12-13,CA,489,39512223 -2023-12-14,CA,423,39512223 -2023-12-15,CA,509,39512223 -2023-12-16,CA,560,39512223 -2023-12-17,CA,457,39512223 -2023-12-18,CA,493,39512223 -2023-12-19,CA,537,39512223 -2023-12-20,CA,505,39512223 -2023-12-21,CA,493,39512223 -2023-12-22,CA,430,39512223 -2023-12-23,CA,495,39512223 -2023-12-24,CA,468,39512223 -2023-12-25,CA,423,39512223 -2023-12-26,CA,419,39512223 -2023-12-27,CA,377,39512223 -2023-12-28,CA,356,39512223 -2023-12-29,CA,395,39512223 -2023-12-30,CA,390,39512223 -2023-12-31,CA,295,39512223 -2024-01-01,CA,283,39512223 -2024-01-02,CA,333,39512223 -2024-01-03,CA,244,39512223 -2024-01-04,CA,271,39512223 -2024-01-05,CA,235,39512223 -2024-01-06,CA,237,39512223 -2024-01-07,CA,228,39512223 -2024-01-08,CA,221,39512223 -2024-01-09,CA,151,39512223 -2024-01-10,CA,168,39512223 -2024-01-11,CA,166,39512223 -2024-01-12,CA,133,39512223 -2024-01-13,CA,125,39512223 -2024-01-14,CA,105,39512223 -2024-01-15,CA,105,39512223 -2024-01-16,CA,95,39512223 -2024-01-17,CA,80,39512223 -2024-01-18,CA,79,39512223 -2024-01-19,CA,57,39512223 -2024-01-20,CA,60,39512223 -2024-01-21,CA,51,39512223 -2024-01-22,CA,49,39512223 -2024-01-23,CA,30,39512223 -2024-01-24,CA,43,39512223 -2024-01-25,CA,38,39512223 -2024-01-26,CA,39,39512223 -2024-01-27,CA,35,39512223 -2024-01-28,CA,25,39512223 -2024-01-29,CA,19,39512223 -2024-01-30,CA,20,39512223 -2024-01-31,CA,15,39512223 -2024-02-01,CA,20,39512223 -2024-02-02,CA,15,39512223 -2024-02-03,CA,15,39512223 -2024-02-04,CA,9,39512223 -2024-02-05,CA,12,39512223 -2024-02-06,CA,9,39512223 -2024-02-07,CA,16,39512223 -2024-02-08,CA,19,39512223 -2024-02-09,CA,22,39512223 -2024-02-10,CA,13,39512223 -2024-02-11,CA,15,39512223 -2024-02-12,CA,15,39512223 -2024-02-13,CA,13,39512223 -2024-02-14,CA,12,39512223 -2024-02-15,CA,13,39512223 -2024-02-16,CA,22,39512223 -2024-02-17,CA,19,39512223 -2024-02-18,CA,20,39512223 -2024-02-19,CA,15,39512223 -2024-02-20,CA,15,39512223 -2024-02-21,CA,15,39512223 -2024-02-22,CA,19,39512223 -2024-02-23,CA,16,39512223 -2024-02-24,CA,15,39512223 -2024-02-25,CA,23,39512223 -2024-02-26,CA,20,39512223 -2024-02-27,CA,19,39512223 -2024-02-28,CA,22,39512223 -2024-02-29,CA,22,39512223 -2024-03-01,CA,28,39512223 -2024-03-02,CA,21,39512223 -2024-03-03,CA,20,39512223 -2024-03-04,CA,27,39512223 diff --git a/pyrenew/datasets/synthetic_CA_120/daily_infections.csv b/pyrenew/datasets/synthetic_CA_120/daily_infections.csv deleted file mode 100644 index d22e74f1..00000000 --- a/pyrenew/datasets/synthetic_CA_120/daily_infections.csv +++ /dev/null @@ -1,121 +0,0 @@ -date,true_infections,true_rt -2023-11-06,19756.11054520949,1.2 -2023-11-07,22137.832061728463,1.1933333333333334 -2023-11-08,24713.05529723112,1.1866666666666665 -2023-11-09,27483.293945028956,1.18 -2023-11-10,30447.758878007327,1.1733333333333333 -2023-11-11,33603.09179301876,1.1666666666666665 -2023-11-12,36943.12084926061,1.16 -2023-11-13,40458.647149085045,1.1533333333333333 -2023-11-14,44137.271061809806,1.1466666666666667 -2023-11-15,47963.348602612365,1.14 -2023-11-16,51917.84793339851,1.1333333333333333 -2023-11-17,55978.37022993391,1.1266666666666667 -2023-11-18,60119.22285114966,1.1199999999999999 -2023-11-19,64311.56901533906,1.1133333333333333 -2023-11-20,68523.65715311568,1.1066666666666667 -2023-11-21,72721.13103142017,1.1 -2023-11-22,76867.41944068302,1.0933333333333333 -2023-11-23,80924.2017246116,1.0866666666666667 -2023-11-24,84851.94303922662,1.08 -2023-11-25,88610.49054041045,1.0733333333333333 -2023-11-26,92159.71934937873,1.0666666666666667 -2023-11-27,95460.2149212564,1.06 -2023-11-28,98473.97653721012,1.0533333333333332 -2023-11-29,101165.12513683562,1.0466666666666666 -2023-11-30,103500.59768685527,1.04 -2023-12-01,105450.80981140169,1.0333333333333332 -2023-12-02,106990.26853536024,1.0266666666666666 -2023-12-03,108098.11774003913,1.02 -2023-12-04,108758.6002989673,1.0133333333333334 -2023-12-05,108961.42282384375,1.0066666666666666 -2023-12-06,108702.01145320752,1.0 -2023-12-07,107981.65008103396,0.9933333333333333 -2023-12-08,106807.49574885165,0.9866666666666667 -2023-12-09,105192.46949463306,0.98 -2023-12-10,103155.02463303752,0.9733333333333334 -2023-12-11,100718.79809568585,0.9666666666666667 -2023-12-12,97912.15394717995,0.96 -2023-12-13,94767.63137826548,0.9533333333333334 -2023-12-14,91321.31223948221,0.9466666666666667 -2023-12-15,87612.1254121439,0.94 -2023-12-16,83681.10693644159,0.9333333333333333 -2023-12-17,79570.63577318197,0.9266666666666667 -2023-12-18,75323.66533931163,0.9199999999999999 -2023-12-19,70982.97053090445,0.9133333333333333 -2023-12-20,66590.42886290187,0.9066666666666667 -2023-12-21,62186.35267198246,0.9 -2023-12-22,57808.88713067756,0.8933333333333333 -2023-12-23,53493.48620967932,0.8866666666666667 -2023-12-24,49272.47581752002,0.88 -2023-12-25,45174.71026669215,0.8733333333333333 -2023-12-26,41225.32508890801,0.8666666666666667 -2023-12-27,37445.58617150747,0.8600000000000001 -2023-12-28,33852.8323243111,0.8533333333333333 -2023-12-29,30460.50580940868,0.8466666666666667 -2023-12-30,27278.263155260243,0.8400000000000001 -2023-12-31,24312.156790105855,0.8333333333333334 -2024-01-01,21564.876704915074,0.8266666666666667 -2024-01-02,19036.040507503185,0.8200000000000001 -2024-01-03,16722.519850180324,0.8133333333333334 -2024-01-04,14618.791277181628,0.8066666666666666 -2024-01-05,12717.300002393937,0.8 -2024-01-06,11205.412114101426,0.8075 -2024-01-07,9925.754783687535,0.8150000000000001 -2024-01-08,8838.84684474518,0.8225 -2024-01-09,7912.576007747334,0.8300000000000001 -2024-01-10,7120.727700272463,0.8375 -2024-01-11,6441.825882807316,0.8450000000000001 -2024-01-12,5858.216338610594,0.8525 -2024-01-13,5355.339128127326,0.8600000000000001 -2024-01-14,4921.050758957707,0.8675 -2024-01-15,4545.338518068657,0.875 -2024-01-16,4219.883804499359,0.8825000000000001 -2024-01-17,3937.758638554082,0.8900000000000001 -2024-01-18,3693.1812564367424,0.8975000000000001 -2024-01-19,3481.3187502359347,0.905 -2024-01-20,3298.127311750783,0.9125000000000001 -2024-01-21,3140.222655134437,0.92 -2024-01-22,3004.7748166431147,0.9275000000000001 -2024-01-23,2889.422557099414,0.935 -2024-01-24,2792.2038372695265,0.9425000000000001 -2024-01-25,2711.4994312736244,0.9500000000000001 -2024-01-26,2645.9873791490927,0.9575 -2024-01-27,2594.6064611590755,0.9650000000000001 -2024-01-28,2556.5272634637226,0.9725000000000001 -2024-01-29,2531.1297188751973,0.9800000000000001 -2024-01-30,2517.9862644961117,0.9875 -2024-01-31,2516.8499736912167,0.9950000000000001 -2024-02-01,2527.647203642624,1.0025000000000002 -2024-02-02,2550.4744608668757,1.01 -2024-02-03,2585.599332958829,1.0175 -2024-02-04,2633.4654720916783,1.0250000000000001 -2024-02-05,2694.7017504746477,1.0325000000000002 -2024-02-06,2770.13584583844,1.04 -2024-02-07,2860.8126619369277,1.0475 -2024-02-08,2968.0181512451327,1.0550000000000002 -2024-02-09,3093.3092914120125,1.0625 -2024-02-10,3238.551181553564,1.07 -2024-02-11,3405.962478571058,1.0775000000000001 -2024-02-12,3598.1706987346743,1.085 -2024-02-13,3818.2792797395227,1.0925 -2024-02-14,4069.948750608039,1.1 -2024-02-15,4278.804259714245,1.0875000000000001 -2024-02-16,4464.378292782625,1.0750000000000002 -2024-02-17,4622.498189584832,1.0625 -2024-02-18,4749.429210576598,1.05 -2024-02-19,4842.016091859674,1.0375 -2024-02-20,4897.808898382193,1.0250000000000001 -2024-02-21,4915.166344392091,1.0125 -2024-02-22,4893.330641693712,1.0 -2024-02-23,4832.51409477793,0.9875 -2024-02-24,4733.84928467361,0.9750000000000001 -2024-02-25,4599.3707190434,0.9625 -2024-02-26,4431.94639781228,0.95 -2024-02-27,4235.176372034393,0.9375 -2024-02-28,4013.263271199615,0.925 -2024-02-29,3770.8613226904004,0.9125 -2024-03-01,3512.911514102406,0.9 -2024-03-02,3244.47115210662,0.8875 -2024-03-03,2970.5462345139176,0.875 -2024-03-04,2695.934508618182,0.8625 diff --git a/pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv deleted file mode 100644 index 0b1d3b8f..00000000 --- a/pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv +++ /dev/null @@ -1,17 +0,0 @@ -week_end,weekly_hosp_admits,location,pop -2023-11-18,667,CA,39512223 -2023-11-25,1404,CA,39512223 -2023-12-02,2206,CA,39512223 -2023-12-09,2883,CA,39512223 -2023-12-16,3464,CA,39512223 -2023-12-23,3410,CA,39512223 -2023-12-30,2828,CA,39512223 -2024-01-06,1898,CA,39512223 -2024-01-13,1192,CA,39512223 -2024-01-20,581,CA,39512223 -2024-01-27,285,CA,39512223 -2024-02-03,129,CA,39512223 -2024-02-10,100,CA,39512223 -2024-02-17,109,CA,39512223 -2024-02-24,115,CA,39512223 -2024-03-02,155,CA,39512223 diff --git a/pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv b/pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv new file mode 100644 index 00000000..3f4c4651 --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv @@ -0,0 +1,127 @@ +date,geo_value,disease,ed_visits +2023-11-05,CA,COVID-19,95 +2023-11-06,CA,COVID-19,125 +2023-11-07,CA,COVID-19,138 +2023-11-08,CA,COVID-19,136 +2023-11-09,CA,COVID-19,172 +2023-11-10,CA,COVID-19,161 +2023-11-11,CA,COVID-19,155 +2023-11-12,CA,COVID-19,133 +2023-11-13,CA,COVID-19,273 +2023-11-14,CA,COVID-19,263 +2023-11-15,CA,COVID-19,360 +2023-11-16,CA,COVID-19,296 +2023-11-17,CA,COVID-19,202 +2023-11-18,CA,COVID-19,272 +2023-11-19,CA,COVID-19,312 +2023-11-20,CA,COVID-19,510 +2023-11-21,CA,COVID-19,448 +2023-11-22,CA,COVID-19,591 +2023-11-23,CA,COVID-19,529 +2023-11-24,CA,COVID-19,476 +2023-11-25,CA,COVID-19,403 +2023-11-26,CA,COVID-19,536 +2023-11-27,CA,COVID-19,603 +2023-11-28,CA,COVID-19,1160 +2023-11-29,CA,COVID-19,989 +2023-11-30,CA,COVID-19,769 +2023-12-01,CA,COVID-19,901 +2023-12-02,CA,COVID-19,750 +2023-12-03,CA,COVID-19,617 +2023-12-04,CA,COVID-19,651 +2023-12-05,CA,COVID-19,729 +2023-12-06,CA,COVID-19,885 +2023-12-07,CA,COVID-19,1026 +2023-12-08,CA,COVID-19,792 +2023-12-09,CA,COVID-19,675 +2023-12-10,CA,COVID-19,772 +2023-12-11,CA,COVID-19,896 +2023-12-12,CA,COVID-19,623 +2023-12-13,CA,COVID-19,786 +2023-12-14,CA,COVID-19,777 +2023-12-15,CA,COVID-19,904 +2023-12-16,CA,COVID-19,561 +2023-12-17,CA,COVID-19,531 +2023-12-18,CA,COVID-19,983 +2023-12-19,CA,COVID-19,973 +2023-12-20,CA,COVID-19,583 +2023-12-21,CA,COVID-19,504 +2023-12-22,CA,COVID-19,660 +2023-12-23,CA,COVID-19,647 +2023-12-24,CA,COVID-19,490 +2023-12-25,CA,COVID-19,721 +2023-12-26,CA,COVID-19,659 +2023-12-27,CA,COVID-19,479 +2023-12-28,CA,COVID-19,352 +2023-12-29,CA,COVID-19,348 +2023-12-30,CA,COVID-19,320 +2023-12-31,CA,COVID-19,229 +2024-01-01,CA,COVID-19,269 +2024-01-02,CA,COVID-19,309 +2024-01-03,CA,COVID-19,326 +2024-01-04,CA,COVID-19,277 +2024-01-05,CA,COVID-19,235 +2024-01-06,CA,COVID-19,167 +2024-01-07,CA,COVID-19,116 +2024-01-08,CA,COVID-19,168 +2024-01-09,CA,COVID-19,129 +2024-01-10,CA,COVID-19,165 +2024-01-11,CA,COVID-19,91 +2024-01-12,CA,COVID-19,91 +2024-01-13,CA,COVID-19,54 +2024-01-14,CA,COVID-19,68 +2024-01-15,CA,COVID-19,88 +2024-01-16,CA,COVID-19,82 +2024-01-17,CA,COVID-19,64 +2024-01-18,CA,COVID-19,47 +2024-01-19,CA,COVID-19,40 +2024-01-20,CA,COVID-19,30 +2024-01-21,CA,COVID-19,28 +2024-01-22,CA,COVID-19,34 +2024-01-23,CA,COVID-19,27 +2024-01-24,CA,COVID-19,27 +2024-01-25,CA,COVID-19,43 +2024-01-26,CA,COVID-19,29 +2024-01-27,CA,COVID-19,21 +2024-01-28,CA,COVID-19,20 +2024-01-29,CA,COVID-19,28 +2024-01-30,CA,COVID-19,20 +2024-01-31,CA,COVID-19,23 +2024-02-01,CA,COVID-19,19 +2024-02-02,CA,COVID-19,23 +2024-02-03,CA,COVID-19,16 +2024-02-04,CA,COVID-19,15 +2024-02-05,CA,COVID-19,13 +2024-02-06,CA,COVID-19,13 +2024-02-07,CA,COVID-19,13 +2024-02-08,CA,COVID-19,13 +2024-02-09,CA,COVID-19,11 +2024-02-10,CA,COVID-19,9 +2024-02-11,CA,COVID-19,13 +2024-02-12,CA,COVID-19,9 +2024-02-13,CA,COVID-19,36 +2024-02-14,CA,COVID-19,25 +2024-02-15,CA,COVID-19,18 +2024-02-16,CA,COVID-19,26 +2024-02-17,CA,COVID-19,18 +2024-02-18,CA,COVID-19,23 +2024-02-19,CA,COVID-19,25 +2024-02-20,CA,COVID-19,34 +2024-02-21,CA,COVID-19,20 +2024-02-22,CA,COVID-19,21 +2024-02-23,CA,COVID-19,32 +2024-02-24,CA,COVID-19,24 +2024-02-25,CA,COVID-19,27 +2024-02-26,CA,COVID-19,36 +2024-02-27,CA,COVID-19,44 +2024-02-28,CA,COVID-19,39 +2024-02-29,CA,COVID-19,46 +2024-03-01,CA,COVID-19,23 +2024-03-02,CA,COVID-19,35 +2024-03-03,CA,COVID-19,29 +2024-03-04,CA,COVID-19,34 +2024-03-05,CA,COVID-19,27 +2024-03-06,CA,COVID-19,39 +2024-03-07,CA,COVID-19,31 +2024-03-08,CA,COVID-19,27 +2024-03-09,CA,COVID-19,13 diff --git a/pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv new file mode 100644 index 00000000..b36da6d5 --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv @@ -0,0 +1,127 @@ +date,geo_value,daily_hosp_admits,pop +2023-11-05,CA,28,39512223 +2023-11-06,CA,31,39512223 +2023-11-07,CA,43,39512223 +2023-11-08,CA,30,39512223 +2023-11-09,CA,57,39512223 +2023-11-10,CA,47,39512223 +2023-11-11,CA,66,39512223 +2023-11-12,CA,94,39512223 +2023-11-13,CA,66,39512223 +2023-11-14,CA,81,39512223 +2023-11-15,CA,93,39512223 +2023-11-16,CA,122,39512223 +2023-11-17,CA,146,39512223 +2023-11-18,CA,119,39512223 +2023-11-19,CA,176,39512223 +2023-11-20,CA,209,39512223 +2023-11-21,CA,199,39512223 +2023-11-22,CA,225,39512223 +2023-11-23,CA,228,39512223 +2023-11-24,CA,260,39512223 +2023-11-25,CA,280,39512223 +2023-11-26,CA,281,39512223 +2023-11-27,CA,306,39512223 +2023-11-28,CA,328,39512223 +2023-11-29,CA,336,39512223 +2023-11-30,CA,361,39512223 +2023-12-01,CA,365,39512223 +2023-12-02,CA,403,39512223 +2023-12-03,CA,457,39512223 +2023-12-04,CA,433,39512223 +2023-12-05,CA,399,39512223 +2023-12-06,CA,434,39512223 +2023-12-07,CA,449,39512223 +2023-12-08,CA,442,39512223 +2023-12-09,CA,507,39512223 +2023-12-10,CA,524,39512223 +2023-12-11,CA,552,39512223 +2023-12-12,CA,528,39512223 +2023-12-13,CA,460,39512223 +2023-12-14,CA,554,39512223 +2023-12-15,CA,611,39512223 +2023-12-16,CA,506,39512223 +2023-12-17,CA,546,39512223 +2023-12-18,CA,598,39512223 +2023-12-19,CA,566,39512223 +2023-12-20,CA,554,39512223 +2023-12-21,CA,490,39512223 +2023-12-22,CA,566,39512223 +2023-12-23,CA,537,39512223 +2023-12-24,CA,491,39512223 +2023-12-25,CA,491,39512223 +2023-12-26,CA,445,39512223 +2023-12-27,CA,424,39512223 +2023-12-28,CA,470,39512223 +2023-12-29,CA,467,39512223 +2023-12-30,CA,373,39512223 +2023-12-31,CA,449,39512223 +2024-01-01,CA,347,39512223 +2024-01-02,CA,317,39512223 +2024-01-03,CA,361,39512223 +2024-01-04,CA,290,39512223 +2024-01-05,CA,269,39512223 +2024-01-06,CA,238,39512223 +2024-01-07,CA,224,39512223 +2024-01-08,CA,201,39512223 +2024-01-09,CA,185,39512223 +2024-01-10,CA,199,39512223 +2024-01-11,CA,179,39512223 +2024-01-12,CA,168,39512223 +2024-01-13,CA,143,39512223 +2024-01-14,CA,142,39512223 +2024-01-15,CA,128,39512223 +2024-01-16,CA,109,39512223 +2024-01-17,CA,108,39512223 +2024-01-18,CA,79,39512223 +2024-01-19,CA,81,39512223 +2024-01-20,CA,69,39512223 +2024-01-21,CA,66,39512223 +2024-01-22,CA,42,39512223 +2024-01-23,CA,57,39512223 +2024-01-24,CA,50,39512223 +2024-01-25,CA,50,39512223 +2024-01-26,CA,45,39512223 +2024-01-27,CA,33,39512223 +2024-01-28,CA,25,39512223 +2024-01-29,CA,25,39512223 +2024-01-30,CA,19,39512223 +2024-01-31,CA,23,39512223 +2024-02-01,CA,18,39512223 +2024-02-02,CA,17,39512223 +2024-02-03,CA,11,39512223 +2024-02-04,CA,13,39512223 +2024-02-05,CA,10,39512223 +2024-02-06,CA,17,39512223 +2024-02-07,CA,9,39512223 +2024-02-08,CA,13,39512223 +2024-02-09,CA,14,39512223 +2024-02-10,CA,14,39512223 +2024-02-11,CA,14,39512223 +2024-02-12,CA,11,39512223 +2024-02-13,CA,10,39512223 +2024-02-14,CA,11,39512223 +2024-02-15,CA,20,39512223 +2024-02-16,CA,17,39512223 +2024-02-17,CA,17,39512223 +2024-02-18,CA,12,39512223 +2024-02-19,CA,12,39512223 +2024-02-20,CA,11,39512223 +2024-02-21,CA,14,39512223 +2024-02-22,CA,12,39512223 +2024-02-23,CA,11,39512223 +2024-02-24,CA,18,39512223 +2024-02-25,CA,15,39512223 +2024-02-26,CA,15,39512223 +2024-02-27,CA,18,39512223 +2024-02-28,CA,23,39512223 +2024-02-29,CA,13,39512223 +2024-03-01,CA,16,39512223 +2024-03-02,CA,23,39512223 +2024-03-03,CA,13,39512223 +2024-03-04,CA,24,39512223 +2024-03-05,CA,23,39512223 +2024-03-06,CA,26,39512223 +2024-03-07,CA,25,39512223 +2024-03-08,CA,19,39512223 +2024-03-09,CA,21,39512223 diff --git a/pyrenew/datasets/synthetic_CA_126/daily_infections.csv b/pyrenew/datasets/synthetic_CA_126/daily_infections.csv new file mode 100644 index 00000000..461579ba --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/daily_infections.csv @@ -0,0 +1,127 @@ +date,true_infections,true_rt +2023-11-05,19756.11054520949,1.2 +2023-11-06,22143.721349321364,1.1936507936507936 +2023-11-07,24730.701954748823,1.1873015873015873 +2023-11-08,27520.279922418667,1.180952380952381 +2023-11-09,30513.614670228315,1.1746031746031746 +2023-11-10,33709.526838297024,1.1682539682539683 +2023-11-11,37104.241812393724,1.1619047619047618 +2023-11-12,40691.15530610575,1.1555555555555554 +2023-11-13,44460.6292337521,1.1492063492063491 +2023-11-14,48399.903877935554,1.1428571428571428 +2023-11-15,52492.90723467809,1.1365079365079365 +2023-11-16,56720.213937022985,1.1301587301587301 +2023-11-17,61059.047787575444,1.1238095238095238 +2023-11-18,65483.352178571426,1.1174603174603175 +2023-11-19,69963.9330002031,1.1111111111111112 +2023-11-20,74468.67701680085,1.1047619047619048 +2023-11-21,78962.84682237565,1.0984126984126985 +2023-11-22,83409.4513679222,1.092063492063492 +2023-11-23,87769.68897968013,1.0857142857142856 +2023-11-24,92003.45734561312,1.0793650793650793 +2023-11-25,96069.92272472705,1.073015873015873 +2023-11-26,99928.13841545957,1.0666666666666667 +2023-11-27,103537.7004820302,1.0603174603174603 +2023-11-28,106859.4269560898,1.053968253968254 +2023-11-29,109856.04528657829,1.0476190476190477 +2023-11-30,112492.87177639856,1.0412698412698413 +2023-12-01,114738.46618109582,1.0349206349206348 +2023-12-02,116565.24459766244,1.0285714285714285 +2023-12-03,117950.03426758068,1.0222222222222221 +2023-12-04,118874.55496352536,1.0158730158730158 +2023-12-05,119325.81320807226,1.0095238095238095 +2023-12-06,119296.39764740864,1.0031746031746032 +2023-12-07,118784.66641411892,0.9968253968253968 +2023-12-08,117794.82018176407,0.9904761904761905 +2023-12-09,116336.85774463006,0.9841269841269842 +2023-12-10,114426.41424020738,0.9777777777777779 +2023-12-11,112084.48545277533,0.9714285714285714 +2023-12-12,109337.04487359013,0.9650793650793651 +2023-12-13,106214.56322819905,0.9587301587301588 +2023-12-14,102751.4429030742,0.9523809523809524 +2023-12-15,98985.38201301648,0.946031746031746 +2023-12-16,94956.68466518822,0.9396825396825397 +2023-12-17,90707.5352330704,0.9333333333333333 +2023-12-18,86281.25511496875,0.926984126984127 +2023-12-19,81721.56050233734,0.9206349206349207 +2023-12-20,77071.83913347754,0.9142857142857144 +2023-12-21,72374.46289250287,0.907936507936508 +2023-12-22,67670.15148828148,0.9015873015873016 +2023-12-23,62997.40038884645,0.8952380952380953 +2023-12-24,58391.98378409807,0.888888888888889 +2023-12-25,53886.5407047299,0.8825396825396825 +2023-12-26,49510.249645148186,0.8761904761904762 +2023-12-27,45288.594230426665,0.8698412698412699 +2023-12-28,41243.219735682826,0.8634920634920635 +2023-12-29,37391.87770582045,0.8571428571428572 +2023-12-30,33748.45361723469,0.8507936507936509 +2023-12-31,30323.07053841518,0.8444444444444446 +2024-01-01,27122.260133653966,0.8380952380952382 +2024-01-02,24149.191145027726,0.8317460317460318 +2024-01-03,21403.944695610335,0.8253968253968255 +2024-01-04,18883.82537674655,0.8190476190476191 +2024-01-05,16583.69709315589,0.8126984126984127 +2024-01-06,14496.333006565756,0.8063492063492064 +2024-01-07,12612.769595127733,0.8 +2024-01-08,11108.34092566445,0.8071428571428572 +2024-01-09,9832.930973744562,0.8142857142857143 +2024-01-10,8747.948967868044,0.8214285714285715 +2024-01-11,7821.924519498369,0.8285714285714286 +2024-01-12,7029.102928870267,0.8357142857142857 +2024-01-13,6348.334538289707,0.8428571428571429 +2024-01-14,5762.19348424236,0.8500000000000001 +2024-01-15,5256.276056958226,0.8571428571428572 +2024-01-16,4818.547575787828,0.8642857142857143 +2024-01-17,4439.060337562475,0.8714285714285714 +2024-01-18,4109.530594712926,0.8785714285714287 +2024-01-19,3823.0425101568535,0.8857142857142858 +2024-01-20,3573.8090602874067,0.8928571428571429 +2024-01-21,3356.978392072291,0.9 +2024-01-22,3168.476589950405,0.9071428571428573 +2024-01-23,3004.8797151710523,0.9142857142857144 +2024-01-24,2863.309519082173,0.9214285714285715 +2024-01-25,2741.348211451286,0.9285714285714286 +2024-01-26,2636.96884685248,0.9357142857142857 +2024-01-27,2548.4784620555247,0.942857142857143 +2024-01-28,2474.471704111163,0.9500000000000001 +2024-01-29,2413.7931481554974,0.9571428571428572 +2024-01-30,2365.5068722504175,0.9642857142857144 +2024-01-31,2328.872153958227,0.9714285714285715 +2024-02-01,2303.3243955688663,0.9785714285714286 +2024-02-02,2288.460584558283,0.9857142857142858 +2024-02-03,2284.028762651382,0.9928571428571429 +2024-02-04,2289.921119069131,1.0 +2024-02-05,2306.170447669907,1.0071428571428571 +2024-02-06,2332.9498193450536,1.0142857142857142 +2024-02-07,2370.5754250480477,1.0214285714285716 +2024-02-08,2419.5126456106736,1.0285714285714287 +2024-02-09,2480.3855061775553,1.0357142857142858 +2024-02-10,2553.9897797971385,1.042857142857143 +2024-02-11,2641.3101207486616,1.05 +2024-02-12,2743.5417382594255,1.0571428571428572 +2024-02-13,2862.1172706974176,1.0642857142857143 +2024-02-14,2998.739695320636,1.0714285714285716 +2024-02-15,3155.4223166337065,1.0785714285714287 +2024-02-16,3334.537126326037,1.0857142857142859 +2024-02-17,3538.873130663702,1.092857142857143 +2024-02-18,3771.7066107178266,1.1 +2024-02-19,3967.4339363506274,1.0880952380952382 +2024-02-20,4143.30298203913,1.0761904761904764 +2024-02-21,4295.597223486165,1.0642857142857143 +2024-02-22,4420.940146912601,1.0523809523809524 +2024-02-23,4516.41535496139,1.0404761904761906 +2024-02-24,4579.67554658712,1.0285714285714287 +2024-02-25,4609.03482820428,1.0166666666666666 +2024-02-26,4603.539398956749,1.0047619047619047 +2024-02-27,4563.052362067281,0.9928571428571429 +2024-02-28,4488.221220621856,0.980952380952381 +2024-02-29,4380.476066032479,0.969047619047619 +2024-03-01,4241.985599243413,0.9571428571428572 +2024-03-02,4075.5851128393456,0.9452380952380952 +2024-03-03,3884.679728108972,0.9333333333333333 +2024-03-04,3673.127496098455,0.9214285714285715 +2024-03-05,3445.108006398925,0.9095238095238095 +2024-03-06,3204.982810520835,0.8976190476190476 +2024-03-07,2957.1543245741104,0.8857142857142857 +2024-03-08,2705.9296732363823,0.8738095238095238 +2024-03-09,2455.3954664262037,0.861904761904762 diff --git a/pyrenew/datasets/synthetic_CA_120/true_parameters.json b/pyrenew/datasets/synthetic_CA_126/true_parameters.json similarity index 86% rename from pyrenew/datasets/synthetic_CA_120/true_parameters.json rename to pyrenew/datasets/synthetic_CA_126/true_parameters.json index fecf1091..a95fe24c 100644 --- a/pyrenew/datasets/synthetic_CA_120/true_parameters.json +++ b/pyrenew/datasets/synthetic_CA_126/true_parameters.json @@ -1,8 +1,8 @@ { "description": "True parameters used to generate synthetic 120-day CA data. All values are known ground truth for posterior recovery checks.", "population": 39512223, - "start_date": "2023-11-06", - "n_days": 120, + "start_date": "2023-11-05", + "n_days": 126, "n_init": 55, "rng_seed": 20240101, "generation_interval_pmf": [ @@ -17,17 +17,17 @@ "i0_per_capita": 0.0005, "rt_trajectory": { "phase_1": { - "days": 60, + "days": 63, "start": 1.2, "end": 0.8 }, "phase_2": { - "days": 40, + "days": 42, "start": 0.8, "end": 1.1 }, "phase_3": { - "days": 20, + "days": 21, "start": 1.1, "end": 0.85 } @@ -35,7 +35,8 @@ "hospitalizations": { "ihr": 0.005, "delay_pmf_source": "infection_admission_interval.tsv", - "negbinom_concentration": 350.0, + "negbinom_concentration_daily": 350.0, + "negbinom_concentration_weekly": 100.0, "temporal_resolutions": [ "daily", "weekly_epiweek" diff --git a/pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv new file mode 100644 index 00000000..efc9c9d2 --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv @@ -0,0 +1,19 @@ +week_end,weekly_hosp_admits,location,pop +2023-11-11,341,CA,39512223 +2023-11-18,624,CA,39512223 +2023-11-25,1223,CA,39512223 +2023-12-02,2450,CA,39512223 +2023-12-09,3227,CA,39512223 +2023-12-16,4341,CA,39512223 +2023-12-23,4710,CA,39512223 +2023-12-30,3924,CA,39512223 +2024-01-06,2067,CA,39512223 +2024-01-13,1627,CA,39512223 +2024-01-20,625,CA,39512223 +2024-01-27,336,CA,39512223 +2024-02-03,151,CA,39512223 +2024-02-10,93,CA,39512223 +2024-02-17,70,CA,39512223 +2024-02-24,79,CA,39512223 +2024-03-02,114,CA,39512223 +2024-03-09,147,CA,39512223 diff --git a/pyrenew/datasets/synthetic_data.py b/pyrenew/datasets/synthetic_data.py index 13bcfd37..9b944ceb 100644 --- a/pyrenew/datasets/synthetic_data.py +++ b/pyrenew/datasets/synthetic_data.py @@ -1,7 +1,7 @@ """ -Loaders for synthetic 120-day California test data. +Loaders for synthetic 126-day California test data. -The synthetic dataset is generated by ``datagen_he_CA_120.py`` and contains +The synthetic dataset is generated by ``datagen_he_CA_126.py`` and contains daily infections, hospital admissions, ED visits, and ground-truth parameters for a single-jurisdiction COVID-19 scenario. """ @@ -16,14 +16,14 @@ def _synthetic_path() -> files: """ - Return the path to the synthetic_CA_120 data directory. + Return the path to the synthetic_CA_126 data directory. Returns ------- importlib.resources.abc.Traversable - Path to the synthetic_CA_120 package directory. + Path to the synthetic_CA_126 package directory. """ - return files("pyrenew.datasets") / "synthetic_CA_120" + return files("pyrenew.datasets") / "synthetic_CA_126" def load_synthetic_true_parameters() -> dict: diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 8a124348..768eee25 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -225,7 +225,17 @@ def validate_data( f"Available: {list(self.observations.keys())}" ) - self.observations[name].validate_data( + obs = self.observations[name] + if getattr(obs, "aggregation_period", 1) > 1 and ( + "first_day_dow" not in obs_data + ): + raise ValueError( + f"Observation '{name}' has aggregation_period=" + f"{obs.aggregation_period} and requires 'first_day_dow' " + f"in its observation data." + ) + + obs.validate_data( n_total=n_total, n_subpops=pop.n_subpops, **obs_data, diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index c8efa4f8..1258d932 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -10,6 +10,7 @@ from typing import Any from pyrenew.latent.base import BaseLatentInfectionProcess +from pyrenew.latent.temporal_processes import TemporalProcess from pyrenew.model.multisignal_model import MultiSignalModel from pyrenew.observation.base import BaseObservationProcess @@ -195,15 +196,84 @@ def compute_n_initialization_points(self) -> int: return n_init + def _validate_coherence(self) -> None: + """ + Enforce end-to-end coherence between R(t) cadence and observation cadences. + + Called at the start of ``build()`` before any model components are + constructed. Inspects ``self.observations`` for ``aggregation_period`` + and ``period_end_dow`` attributes, and walks ``self.latent_params`` + values for any ``TemporalProcess`` instances to read their + ``step_size`` attribute. + + Checks: + + - All observations sharing the same ``aggregation_period > 1`` must + agree on ``period_end_dow``. + - Every temporal-process ``step_size`` must be ``<=`` the finest + observation ``aggregation_period``. + - If a temporal process has ``step_size > 1``, that ``step_size`` + must equal the ``aggregation_period`` of every observation whose + ``aggregation_period > 1``. + + Raises + ------ + ValueError + If any of the three rules is violated. + """ + agg_by_dow: dict[int, set[int]] = {} + agg_periods: list[int] = [] + for name, obs in self.observations.items(): + P = getattr(obs, "aggregation_period", 1) + agg_periods.append(P) + if P > 1: + agg_by_dow.setdefault(P, set()).add(getattr(obs, "period_end_dow")) + + for P, dows in agg_by_dow.items(): + if len(dows) > 1: + raise ValueError( + f"Observations with aggregation_period={P} must agree on " + f"period_end_dow; got values {sorted(dows)}" + ) + + temporal_processes = { + name: value + for name, value in self.latent_params.items() + if isinstance(value, TemporalProcess) + } + + finest = min(agg_periods) if agg_periods else 1 + coarse_periods = {P for P in agg_periods if P > 1} + + for param_name, process in temporal_processes.items(): + step_size = getattr(process, "step_size", 1) + if step_size > finest: + raise ValueError( + f"Temporal process '{param_name}' has step_size={step_size} " + f"exceeding the finest observation aggregation_period " + f"({finest}). Parameterize R(t) at least as finely as the " + f"finest observed signal." + ) + if step_size > 1: + mismatched = {P for P in coarse_periods if P != step_size} + if mismatched: + raise ValueError( + f"Temporal process '{param_name}' has step_size=" + f"{step_size} but observations include " + f"aggregation_period(s) {sorted(mismatched)} that do " + f"not match. All coarse cadences must agree." + ) + def build(self) -> MultiSignalModel: """ Build the multi-signal model with computed n_initialization_points. This method: - 1. Computes n_initialization_points from all components - 2. Constructs the latent process with the computed value - 3. Creates a MultiSignalModel with automatic infection routing - 4. Validates that observation/latent types are compatible + 1. Enforces coherence between R(t) cadence and observation cadences + 2. Computes n_initialization_points from all components + 3. Constructs the latent process with the computed value + 4. Creates a MultiSignalModel with automatic infection routing + 5. Validates that observation/latent types are compatible Can be called multiple times to create multiple model instances. @@ -215,11 +285,14 @@ def build(self) -> MultiSignalModel: Raises ------ ValueError - If latent process not configured + If latent process not configured, or if R(t) and observation + cadences are incoherent. """ if self.latent_class is None: raise ValueError("Must call configure_latent() before build()") + self._validate_coherence() + # Compute n_initialization_points n_init = self.compute_n_initialization_points() diff --git a/test/conftest.py b/test/conftest.py index 27f54f8f..bf282e2f 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,8 +3,22 @@ This module provides reusable fixtures for creating observation processes, test data, and common configurations used across multiple test files. + +The module also sets two JAX environment variables before any jax import so +that all tests (unit and integration) run with 64-bit precision and with +four logical host devices available for parallel MCMC chains. JAX reads +these variables at import time, so they must be set before the first +``import jax`` anywhere in the test process; placing them at the top of +this file (loaded by pytest before any test module) satisfies that +requirement. The ``setdefault`` form respects any value the caller already +set at the shell level (e.g., a CI with different configuration). """ +import os + +os.environ.setdefault("JAX_ENABLE_X64", "true") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4") + import jax.numpy as jnp import numpyro.distributions as dist import pytest diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 08b04b81..f954fbb7 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -18,6 +18,7 @@ load_synthetic_daily_hospital_admissions, load_synthetic_daily_infections, load_synthetic_true_parameters, + load_synthetic_weekly_hospital_admissions, ) from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import AR1 @@ -80,6 +81,19 @@ def daily_ed() -> pl.DataFrame: return load_synthetic_daily_ed_visits() +@pytest.fixture(scope="module") +def weekly_hosp() -> pl.DataFrame: + """ + Load synthetic weekly (MMWR epiweek) hospital admissions. + + Returns + ------- + pl.DataFrame + Columns: week_end, weekly_hosp_admits, location, pop. + """ + return load_synthetic_weekly_hospital_admissions() + + @pytest.fixture(scope="module") def hosp_delay_pmf() -> jnp.ndarray: """ @@ -188,3 +202,71 @@ def he_model( builder.add_observation(ed_obs) return builder.build() + + +@pytest.fixture(scope="module") +def he_weekly_model( + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> PyrenewBuilder: + """ + Build a PopulationInfections model with WEEKLY hospital + DAILY ED observations. + + The hospital observation is aggregated to MMWR epiweeks (Sunday-Saturday, + anchored by ``period_end_dow=5``); the ED observation stays daily with a + day-of-week effect. R(t) is parametrized at the finest observation + cadence (daily) per the coherence rules for mixed-cadence models. + + Parameters + ---------- + hosp_delay_pmf : jnp.ndarray + Infection-to-hospitalization delay PMF. + ed_delay_pmf : jnp.ndarray + Infection-to-ED-visit delay PMF. + ed_day_of_week_effects : jnp.ndarray + Day-of-week multipliers used in synthetic ED generation. + + Returns + ------- + MultiSignalModel + Built model ready for fitting. + """ + gen_int_pmf = jnp.array( + [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] + ) + + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), + I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + ) + + hospital_obs = PopulationCounts( + name="hospital", + ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) + ), + aggregation_period=7, + reporting_schedule="regular", + period_end_dow=5, + ) + builder.add_observation(hospital_obs) + + ed_obs = PopulationCounts( + name="ed", + ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) + ), + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), + ) + builder.add_observation(ed_obs) + + return builder.build() diff --git a/test/integration/test_population_infections_he.py b/test/integration/test_population_infections_he.py index 12ac6553..03f5c4a3 100644 --- a/test/integration/test_population_infections_he.py +++ b/test/integration/test_population_infections_he.py @@ -2,7 +2,7 @@ Integration test: PopulationInfections H+E model with posterior recovery. Fits a PopulationInfections model with hospital admissions and ED visit -observation processes to synthetic 120-day CA data, then checks that +observation processes to synthetic 126-day CA data, then checks that posterior estimates recover known true parameters. """ @@ -21,7 +21,7 @@ pytestmark = pytest.mark.integration -N_DAYS_FIT = 120 +N_DAYS_FIT = 126 NUM_WARMUP = 500 NUM_SAMPLES = 500 NUM_CHAINS = 4 @@ -48,9 +48,9 @@ def test_data_shapes( daily_infections : pl.DataFrame True infections and R(t). """ - assert len(daily_hosp) == 120 - assert len(daily_ed) == 120 - assert len(daily_infections) == 120 + assert len(daily_hosp) == 126 + assert len(daily_ed) == 126 + assert len(daily_infections) == 126 assert "daily_hosp_admits" in daily_hosp.columns assert "ed_visits" in daily_ed.columns assert "true_rt" in daily_infections.columns diff --git a/test/test_ar_process.py b/test/test_ar_process.py index 2af800dd..5bef827d 100755 --- a/test/test_ar_process.py +++ b/test/test_ar_process.py @@ -243,6 +243,23 @@ def test_ar_process_asymptotics(ar_inits, autoreg, noise_sd, n): :n ] + # Closed-form stationary standard deviation for the AR(p) process. + # The previous assertion used `3 * noise_sd`, which is only correct when + # the stationary variance equals the innovation variance (p=0); for AR(1) + # and AR(2) with non-zero autoreg, the stationary SD is strictly larger. + sigma2 = noise_sd**2 + if order == 1: + stationary_sd = jnp.sqrt(sigma2 / (1 - autoreg[0] ** 2)) + elif order == 2: + phi1, phi2 = autoreg[0], autoreg[1] + stationary_sd = jnp.sqrt( + sigma2 * (1 - phi2) / ((1 + phi2) * ((1 - phi2) ** 2 - phi1**2)) + ) + else: + raise NotImplementedError( + f"Stationary SD for AR order {order} not implemented in this test." + ) + with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it @@ -255,4 +272,4 @@ def test_ar_process_asymptotics(ar_inits, autoreg, noise_sd, n): ) assert_array_almost_equal(long_ts[:order], expected_first_entries) - assert jnp.abs(long_ts[-1]) < 3 * noise_sd + assert jnp.abs(long_ts[-1]) < 3 * stationary_sd diff --git a/test/test_datagen_he_CA_120.py b/test/test_datagen_he_CA_126.py similarity index 74% rename from test/test_datagen_he_CA_120.py rename to test/test_datagen_he_CA_126.py index 87aef8e9..ebbfe997 100644 --- a/test/test_datagen_he_CA_120.py +++ b/test/test_datagen_he_CA_126.py @@ -1,5 +1,5 @@ """ -Unit tests for the datagen_he_CA_120 helper functions. +Unit tests for the datagen_he_CA_126 helper functions. """ import json @@ -8,11 +8,11 @@ import numpy as np import polars as pl -import pyrenew.datasets.datagen_he_CA_120 as datagen_mod -from pyrenew.datasets.datagen_he_CA_120 import ( - aggregate_to_epiweeks, +import pyrenew.datasets.datagen_he_CA_126 as datagen_mod +from pyrenew.datasets.datagen_he_CA_126 import ( apply_day_of_week_effects, build_true_rt, + build_weekly_hosp_frame, generate, run_renewal, sample_negbinom, @@ -23,9 +23,9 @@ class TestBuildTrueRt: """Tests for build_true_rt.""" def test_length(self): - """Test that the output has 120 days.""" + """Test that the output has 126 days.""" rt = build_true_rt() - assert len(rt) == 120 + assert len(rt) == 126 def test_starting_value(self): """Test that Rt starts near 1.2.""" @@ -35,8 +35,9 @@ def test_starting_value(self): def test_phase_endpoints(self): """Test that phase transitions occur at expected values.""" rt = build_true_rt() - assert rt[59] < 0.85 - assert rt[60] > 0.79 + # Phase 1 ends at index 62 (last day of 63-day decline), phase 2 starts at 63. + assert rt[62] < 0.85 + assert rt[63] > 0.79 class TestRunRenewal: @@ -116,25 +117,39 @@ def test_handles_near_zero_mu(self): assert result.shape == mu.shape -class TestAggregateToEpiweeks: - """Tests for aggregate_to_epiweeks.""" +class TestBuildWeeklyHospFrame: + """Tests for build_weekly_hosp_frame.""" - def test_complete_weeks_only(self): - """Test that only complete 7-day weeks are returned.""" - daily = np.ones(20) + def test_length_matches_input(self): + """Frame length equals the number of weekly values supplied.""" + weekly = np.array([100.0, 200.0, 300.0]) start = date(2023, 11, 6) - result = aggregate_to_epiweeks(daily, start) + result = build_weekly_hosp_frame(weekly, start) assert isinstance(result, pl.DataFrame) - assert "weekly_hosp_admits" in result.columns - for val in result["weekly_hosp_admits"].to_list(): - assert val == 7.0 + assert len(result) == 3 - def test_weekly_sum(self): - """Test that weekly sums are correct.""" - daily = np.arange(1, 15, dtype=float) + def test_weekly_values_preserved(self): + """Input weekly values appear in the weekly_hosp_admits column.""" + weekly = np.array([100.0, 200.0, 300.0]) + start = date(2023, 11, 6) + result = build_weekly_hosp_frame(weekly, start) + assert result["weekly_hosp_admits"].to_list() == [100.0, 200.0, 300.0] + + def test_first_week_end_is_saturday(self): + """First week_end is the Saturday of the first complete MMWR epiweek.""" + # 2023-11-06 is a Monday; first MMWR Saturday after (Sun-start week) is 2023-11-18. + weekly = np.array([1.0, 2.0]) + start = date(2023, 11, 6) + result = build_weekly_hosp_frame(weekly, start) + assert result["week_end"].to_list() == [date(2023, 11, 18), date(2023, 11, 25)] + + def test_first_week_end_sunday_start(self): + """A Sunday start yields the immediately-following Saturday as first week_end.""" + # 2023-11-05 is a Sunday; first full MMWR week ends Saturday 2023-11-11. + weekly = np.array([5.0]) start = date(2023, 11, 5) - result = aggregate_to_epiweeks(daily, start) - assert len(result) >= 1 + result = build_weekly_hosp_frame(weekly, start) + assert result["week_end"].to_list() == [date(2023, 11, 11)] class TestGenerate: @@ -164,7 +179,7 @@ def test_generate_true_parameters_content(self, tmp_path, monkeypatch): params = json.load(f) assert params["population"] == datagen_mod.POPULATION - assert params["n_days"] == 120 + assert params["n_days"] == 126 assert "hospitalizations" in params assert "ed_visits" in params @@ -174,7 +189,7 @@ def test_generate_daily_infections_shape(self, tmp_path, monkeypatch): generate() df = pl.read_csv(tmp_path / "daily_infections.csv") - assert len(df) == 120 + assert len(df) == 126 assert "true_infections" in df.columns assert "true_rt" in df.columns @@ -184,7 +199,7 @@ def test_generate_hospital_admissions_shape(self, tmp_path, monkeypatch): generate() df = pl.read_csv(tmp_path / "daily_hospital_admissions.csv") - assert len(df) == 120 + assert len(df) == 126 assert "daily_hosp_admits" in df.columns def test_generate_ed_visits_shape(self, tmp_path, monkeypatch): @@ -193,5 +208,5 @@ def test_generate_ed_visits_shape(self, tmp_path, monkeypatch): generate() df = pl.read_csv(tmp_path / "daily_ed_visits.csv") - assert len(df) == 120 + assert len(df) == 126 assert "ed_visits" in df.columns diff --git a/test/test_datasets_synthetic.py b/test/test_datasets_synthetic.py index 52872d78..ca41e5c8 100644 --- a/test/test_datasets_synthetic.py +++ b/test/test_datasets_synthetic.py @@ -1,5 +1,5 @@ """ -Unit tests for synthetic CA 120-day dataset loaders. +Unit tests for synthetic CA 126-day dataset loaders. """ import polars as pl @@ -46,9 +46,9 @@ def test_population_is_positive(self): assert params["population"] > 0 def test_n_days_matches_data(self): - """Test that n_days is 120.""" + """Test that n_days is 126.""" params = load_synthetic_true_parameters() - assert params["n_days"] == 120 + assert params["n_days"] == 126 def test_hospitalization_params(self): """Test that hospitalization sub-dict has expected keys.""" @@ -83,10 +83,10 @@ def test_has_expected_columns(self): assert "true_infections" in df.columns assert "true_rt" in df.columns - def test_has_120_rows(self): - """Test that the dataset has 120 days.""" + def test_has_126_rows(self): + """Test that the dataset has 126 days.""" df = load_synthetic_daily_infections() - assert len(df) == 120 + assert len(df) == 126 def test_infections_are_positive(self): """Test that true infections are positive.""" @@ -114,10 +114,10 @@ def test_has_expected_columns(self): assert "daily_hosp_admits" in df.columns assert "pop" in df.columns - def test_has_120_rows(self): - """Test that the dataset has 120 days.""" + def test_has_126_rows(self): + """Test that the dataset has 126 days.""" df = load_synthetic_daily_hospital_admissions() - assert len(df) == 120 + assert len(df) == 126 def test_admits_are_non_negative(self): """Test that hospital admissions are non-negative.""" @@ -139,10 +139,10 @@ def test_has_expected_columns(self): assert "date" in df.columns assert "ed_visits" in df.columns - def test_has_120_rows(self): - """Test that the dataset has 120 days.""" + def test_has_126_rows(self): + """Test that the dataset has 126 days.""" df = load_synthetic_daily_ed_visits() - assert len(df) == 120 + assert len(df) == 126 def test_visits_are_non_negative(self): """Test that ED visits are non-negative.""" diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 04f5da27..19497048 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -6,10 +6,17 @@ import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import RandomWalk, SubpopulationInfections +from pyrenew.latent import ( + AR1, + PopulationInfections, + RandomWalk, + StepwiseTemporalProcess, + SubpopulationInfections, +) from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import ( NegativeBinomialNoise, + PoissonNoise, PopulationCounts, SubpopulationCounts, ) @@ -443,5 +450,183 @@ def test_shift_times_adds_offset(self, simple_builder): assert jnp.array_equal(shifted, times + n_init) +def _coherence_builder( + *, + single_rt_process, + observations, +): + """ + Build a configured PyrenewBuilder with a PopulationInfections latent and + a supplied temporal process, plus the given observation instances. + + Returns + ------- + PyrenewBuilder + Builder configured but not yet built. + """ + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + single_rt_process=single_rt_process, + ) + for obs in observations: + builder.add_observation(obs) + return builder + + +def _weekly_hosp_counts(name="hospital", period_end_dow=5): + """ + Build a weekly-aggregated PopulationCounts observation with PoissonNoise. + + Returns + ------- + PopulationCounts + Weekly-regular observation anchored to the specified period end dow. + """ + return PopulationCounts( + name=name, + ascertainment_rate_rv=DeterministicVariable(f"{name}_ihr", 0.01), + delay_distribution_rv=DeterministicPMF(f"{name}_delay", jnp.array([1.0])), + noise=PoissonNoise(), + aggregation_period=7, + reporting_schedule="regular", + period_end_dow=period_end_dow, + ) + + +def _daily_ed_counts(name="ed"): + """ + Build a daily PopulationCounts observation with PoissonNoise. + + Returns + ------- + PopulationCounts + Daily-regular observation with no aggregation. + """ + return PopulationCounts( + name=name, + ascertainment_rate_rv=DeterministicVariable(f"{name}_ihr", 0.01), + delay_distribution_rv=DeterministicPMF(f"{name}_delay", jnp.array([1.0])), + noise=PoissonNoise(), + ) + + +class TestBuilderCoherence: + """PyrenewBuilder._validate_coherence enforcement at build() time.""" + + def test_daily_rt_with_daily_observation_passes(self): + """step_size=1 and P=1: valid.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[_daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_weekly_rt_with_weekly_observation_passes(self): + """step_size=7 and P=7: valid when weekly is the only obs cadence.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_daily_rt_with_mixed_observations_passes(self): + """step_size=1 with mixed P=1 + P=7: valid (R(t) at finest cadence).""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[_weekly_hosp_counts(), _daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_weekly_rt_with_daily_observation_raises(self): + """step_size=7 with any daily obs: rule 2 violation.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts(), _daily_ed_counts()], + ) + with pytest.raises(ValueError, match="exceeding the finest"): + builder.build() + + def test_mismatched_weekly_period_end_dow_raises(self): + """Two weekly observations with different period_end_dow: rule 1 violation.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[ + _weekly_hosp_counts(name="hospital", period_end_dow=5), + _weekly_hosp_counts(name="other", period_end_dow=6), + ], + ) + with pytest.raises(ValueError, match="must agree on period_end_dow"): + builder.build() + + def test_matching_weekly_period_end_dow_passes(self): + """Two weekly observations agreeing on period_end_dow: rule 1 passes.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[ + _weekly_hosp_counts(name="hospital", period_end_dow=5), + _weekly_hosp_counts(name="other", period_end_dow=5), + ], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + +class TestMultiSignalValidateDataAnchor: + """MultiSignalModel.validate_data sample-time anchor check for first_day_dow.""" + + def test_missing_first_day_dow_for_weekly_obs_raises(self): + """An observation with aggregation_period>1 must have first_day_dow supplied.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + with pytest.raises(ValueError, match="requires 'first_day_dow'"): + model.validate_data( + n_days_post_init=28, + hospital={"obs": jnp.ones(4) * 5.0}, + ) + + def test_first_day_dow_supplied_for_weekly_obs_passes(self): + """Supplying first_day_dow satisfies the anchor check.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + model.validate_data( + n_days_post_init=28, + hospital={"obs": jnp.ones(4) * 5.0, "first_day_dow": 6}, + ) + + def test_anchor_check_skipped_for_daily_obs(self): + """Daily observations do not require first_day_dow at validate_data time.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[_daily_ed_counts()], + ) + model = builder.build() + n_total = model.latent.n_initialization_points + 30 + model.validate_data( + n_days_post_init=30, + ed={"obs": jnp.ones(n_total) * 5.0}, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From ae61442f1188dfd1fdef1397faf1917a6e77dbae Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Tue, 21 Apr 2026 19:42:59 -0400 Subject: [PATCH 04/17] integration test --- .../test_population_infections_he_weekly.py | 460 ++++++++++++++++++ 1 file changed, 460 insertions(+) create mode 100644 test/integration/test_population_infections_he_weekly.py diff --git a/test/integration/test_population_infections_he_weekly.py b/test/integration/test_population_infections_he_weekly.py new file mode 100644 index 00000000..6c1644e9 --- /dev/null +++ b/test/integration/test_population_infections_he_weekly.py @@ -0,0 +1,460 @@ +""" +Integration test: PopulationInfections H+E model with WEEKLY hospital admissions. + +Same structure as ``test_population_infections_he`` but the hospital +signal is aggregated to MMWR epiweeks. Fits the mixed-cadence model +(weekly hospital + daily ED) to synthetic 126-day CA data and checks +posterior recovery. +""" + +from __future__ import annotations + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpy as np +import polars as pl +import pytest + +from pyrenew.model import MultiSignalModel + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 +# Day 0 of the synthetic data is 2023-11-05, a Sunday (ISO dow = 6). +OBS_START_DOW = 6 + + +def _build_hospital_obs_on_period_grid( + model: MultiSignalModel, + weekly_values: jnp.ndarray, + first_day_dow: int, +) -> jnp.ndarray: + """ + Build a dense weekly-observation array on the model's period grid. + + Pads ``n_pre`` NaN values at the front for periods that precede the + first observed week (periods that overlap the initialization window + and any pre-observation gap). + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + weekly_values : jnp.ndarray + Observed weekly hospital admissions, one per MMWR epiweek. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + jnp.ndarray + Dense array of shape ``(n_periods,)`` with NaN for unobserved + periods and observed counts for periods covered by + ``weekly_values``. + """ + hosp = model.observations["hospital"] + n_init = model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + offset = (hosp.period_end_dow + 1 - first_day_dow) % hosp.aggregation_period + n_periods = (n_total - offset) // hosp.aggregation_period + n_pre = n_periods - len(weekly_values) + return jnp.concatenate( + [jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values] + ) + + +class TestDataAssembly: + """Verify synthetic data can be loaded and aligned to the weekly model.""" + + def test_data_shapes( + self, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + daily_infections: pl.DataFrame, + ) -> None: + """ + Verify synthetic data files have expected row counts and columns. + + Parameters + ---------- + weekly_hosp : pl.DataFrame + Weekly MMWR-epiweek hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + daily_infections : pl.DataFrame + True infections and R(t). + """ + assert len(weekly_hosp) == 18 + assert len(daily_ed) == 126 + assert len(daily_infections) == 126 + assert "weekly_hosp_admits" in weekly_hosp.columns + assert "ed_visits" in daily_ed.columns + assert "true_rt" in daily_infections.columns + + def test_hospital_is_weekly_regular( + self, + he_weekly_model: MultiSignalModel, + ) -> None: + """ + Verify the hospital observation is weekly-aggregated and MMWR-anchored. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + """ + h = he_weekly_model.observations["hospital"] + assert h.aggregation_period == 7 + assert h.reporting_schedule == "regular" + assert h.period_end_dow == 5 + + def test_ed_stays_daily( + self, + he_weekly_model: MultiSignalModel, + ) -> None: + """ + Verify the ED observation remains daily with a day-of-week effect. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + """ + ed = he_weekly_model.observations["ed"] + assert ed.aggregation_period == 1 + assert ed.day_of_week_rv is not None + + def test_weekly_obs_alignment( + self, + he_weekly_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + ) -> None: + """ + Verify the weekly obs array has the correct dense-on-period-grid shape. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + """ + first_day_dow = he_weekly_model.compute_first_day_dow(OBS_START_DOW) + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_model, weekly_values, first_day_dow + ) + + n_init = he_weekly_model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + hosp = he_weekly_model.observations["hospital"] + offset = ( + hosp.period_end_dow + 1 - first_day_dow + ) % hosp.aggregation_period + n_periods = (n_total - offset) // hosp.aggregation_period + + assert hosp_obs.shape == (n_periods,) + assert jnp.isnan(hosp_obs[0]) + assert not jnp.isnan(hosp_obs[-1]) + assert int((~jnp.isnan(hosp_obs)).sum()) == len(weekly_hosp) + + +class TestModelFit: + """Fit the weekly H+E model and check posterior recovery.""" + + @pytest.fixture(scope="class") + def fitted_model( + self, + he_weekly_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """ + Fit the mixed-cadence H+E model to synthetic data via MCMC. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + + Returns + ------- + MultiSignalModel + Model with MCMC results attached. + """ + first_day_dow = he_weekly_model.compute_first_day_dow(OBS_START_DOW) + + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_model, weekly_values, first_day_dow + ) + + ed_obs = he_weekly_model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + population_size = float(weekly_hosp["pop"][0]) + + he_weekly_model.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + hospital={"obs": hosp_obs, "first_day_dow": first_day_dow}, + ed={"obs": ed_obs, "first_day_dow": first_day_dow}, + ) + + samples = he_weekly_model.mcmc.get_samples() + jax.block_until_ready(samples) + return he_weekly_model + + @pytest.fixture(scope="class") + def posterior_dt( + self, + fitted_model: MultiSignalModel, + ): + """ + Convert MCMC samples to an ArviZ DataTree, trimming the init period. + + The hospital signal lives on the weekly period grid (dim ``week``); + the daily-scale sites (``time``) are trimmed by ``n_init``. + + Parameters + ---------- + fitted_model : MultiSignalModel + Model with MCMC results. + + Returns + ------- + xarray.DataTree + ArviZ DataTree with posterior group. + """ + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "hospital_predicted_daily": ["time"], + "hospital_predicted": ["week"], + "ed_predicted": ["time"], + }, + ) + + def trim_init(ds): + """ + Trim the initialization period from the ``time`` dimension only. + + Parameters + ---------- + ds + Dataset to trim. + + Returns + ------- + xarray.Dataset + Dataset with ``time`` sliced to ``[n_init:]``; ``week`` + and other dims pass through unchanged. + """ + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """ + Check that core parameters have acceptable Rhat and ESS. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_rt_posterior_covers_truth( + self, + posterior_dt, + daily_infections: pl.DataFrame, + ) -> None: + """ + Check that the 90% credible interval for R(t) covers the true value + for at least 80% of time points. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + daily_infections : pl.DataFrame + True R(t) trajectory. + """ + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] + rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values + rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values + + true_rt = daily_infections["true_rt"].to_numpy() + + if rt_q05.ndim > 1: + rt_q05 = rt_q05.squeeze() + rt_q95 = rt_q95.squeeze() + + n_compare = min(len(true_rt), len(rt_q05)) + covered = (true_rt[:n_compare] >= rt_q05[:n_compare]) & ( + true_rt[:n_compare] <= rt_q95[:n_compare] + ) + coverage = float(np.mean(covered)) + assert coverage >= 0.80, ( + f"R(t) 90% CI coverage was {coverage:.1%}, expected >= 80%" + ) + + def test_infection_trajectory_shape( + self, + posterior_dt, + ) -> None: + """ + Check the posterior infection trajectory has correct shape and is positive. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + infections = posterior_dt.posterior["latent_infections"] + assert infections.sizes["time"] == N_DAYS_FIT + assert (infections.values > 0).all() + + def test_ascertainment_rates_recover_order_of_magnitude( + self, + posterior_dt, + true_params: dict, + ) -> None: + """ + Check that posterior median IHR and IEDR are within a factor + of 5 of the true values. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + true_params : dict + Ground-truth parameter dictionary. + """ + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + + ihr_median = float( + posterior_dt.posterior["ihr"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior_dt.posterior["iedr"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5, ( + f"IHR median {ihr_median:.4f} not within 5x of true {true_ihr}" + ) + assert true_iedr / 5 <= iedr_median <= true_iedr * 5, ( + f"IEDR median {iedr_median:.4f} not within 5x of true {true_iedr}" + ) + + def test_hospital_predicted_weekly_grid( + self, + posterior_dt, + weekly_hosp: pl.DataFrame, + ) -> None: + """ + Check that hospital posterior predictions live on the weekly grid + and have plausible magnitude relative to observed weekly counts. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + """ + hosp_pred = posterior_dt.posterior["hospital_predicted"] + # n_periods from the weekly grid must be >= the number of observed weeks. + assert hosp_pred.sizes["week"] >= len(weekly_hosp) + + hosp_pred_median = float( + hosp_pred.median(dim=["chain", "draw", "week"]).values + ) + hosp_obs_mean = float(weekly_hosp["weekly_hosp_admits"].mean()) + + assert hosp_pred_median > 0, "Hospital predictions should be positive" + assert hosp_obs_mean / 10 <= hosp_pred_median <= hosp_obs_mean * 10 + + def test_hospital_predicted_daily_has_daily_grid( + self, + posterior_dt, + ) -> None: + """ + Check that the ``hospital_predicted_daily`` deterministic site covers + the daily shared axis. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + hosp_pred_daily = posterior_dt.posterior["hospital_predicted_daily"] + assert hosp_pred_daily.sizes["time"] == N_DAYS_FIT + + def test_ed_predicted_reasonable( + self, + posterior_dt, + daily_ed: pl.DataFrame, + ) -> None: + """ + Check that daily ED posterior predictions have the right shape and + plausible magnitude relative to observed daily ED visits. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + daily_ed : pl.DataFrame + Daily ED visits. + """ + ed_pred = posterior_dt.posterior["ed_predicted"] + assert ed_pred.sizes["time"] == N_DAYS_FIT + + ed_pred_median = float(ed_pred.median(dim=["chain", "draw", "time"]).values) + ed_obs_mean = float(daily_ed["ed_visits"].mean()) + + assert ed_pred_median > 0, "ED predictions should be positive" + assert ed_obs_mean / 10 <= ed_pred_median <= ed_obs_mean * 10 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) From 613c325e6d1f7a313d91e12434cd5f85c6662642 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 23:48:03 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_population_infections_he_weekly.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/test/integration/test_population_infections_he_weekly.py b/test/integration/test_population_infections_he_weekly.py index 6c1644e9..9872059d 100644 --- a/test/integration/test_population_infections_he_weekly.py +++ b/test/integration/test_population_infections_he_weekly.py @@ -64,9 +64,7 @@ def _build_hospital_obs_on_period_grid( offset = (hosp.period_end_dow + 1 - first_day_dow) % hosp.aggregation_period n_periods = (n_total - offset) // hosp.aggregation_period n_pre = n_periods - len(weekly_values) - return jnp.concatenate( - [jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values] - ) + return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) class TestDataAssembly: @@ -156,9 +154,7 @@ def test_weekly_obs_alignment( n_init = he_weekly_model.latent.n_initialization_points n_total = n_init + N_DAYS_FIT hosp = he_weekly_model.observations["hospital"] - offset = ( - hosp.period_end_dow + 1 - first_day_dow - ) % hosp.aggregation_period + offset = (hosp.period_end_dow + 1 - first_day_dow) % hosp.aggregation_period n_periods = (n_total - offset) // hosp.aggregation_period assert hosp_obs.shape == (n_periods,) @@ -406,9 +402,7 @@ def test_hospital_predicted_weekly_grid( # n_periods from the weekly grid must be >= the number of observed weeks. assert hosp_pred.sizes["week"] >= len(weekly_hosp) - hosp_pred_median = float( - hosp_pred.median(dim=["chain", "draw", "week"]).values - ) + hosp_pred_median = float(hosp_pred.median(dim=["chain", "draw", "week"]).values) hosp_obs_mean = float(weekly_hosp["weekly_hosp_admits"].mean()) assert hosp_pred_median > 0, "Hospital predictions should be positive" From eb3bc70ded19d5967d581a1622a2a9464a5165b7 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 22 Apr 2026 11:22:31 -0400 Subject: [PATCH 06/17] changes per co-pilot code review --- pyproject.toml | 5 + pyrenew/observation/base.py | 60 ++++++--- pyrenew/observation/count_observations.py | 19 ++- .../observation/measurement_observations.py | 2 +- test/integration/conftest.py | 6 +- test/test_observation_validation.py | 119 ++++++++++++++---- 6 files changed, 163 insertions(+), 48 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3f84cf70..e4465682 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,11 @@ known_first_party = ["pyrenew", "test"] [tool.deptry.per_rule_ignores] DEP004 = ["arviz", "pytest", "scipy", "bs4"] +[tool.pytest.ini_options] +markers = [ + "integration: integration tests that fit models via MCMC (deselect with '-m \"not integration\"')", +] + [tool.ruff] fix = true diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index c1906aea..c7bc6e3b 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -467,10 +467,11 @@ def _validate_index_array( self, indices: ArrayLike, upper_bound: int, param_name: str ) -> None: """ - Validate an index array has non-negative values within bounds. + Validate an index array is 1D with non-negative values within bounds. - Checks that all values are non-negative integers in ``[0, upper_bound)``. - An empty array is a no-op and passes validation. + Checks that the array is 1D and that all values are non-negative + integers in ``[0, upper_bound)``. An empty 1D array is a no-op + and passes the bounds check. Parameters ---------- @@ -484,9 +485,15 @@ def _validate_index_array( Raises ------ ValueError - If indices contains negative values or values >= upper_bound. + If indices is not 1D, contains negative values, or values + >= upper_bound. """ indices = jnp.asarray(indices) + if indices.ndim != 1: + raise ValueError( + f"Observation '{self.name}': {param_name} must be 1D, " + f"got shape {indices.shape}" + ) if indices.size == 0: return if jnp.any(indices < 0): @@ -542,28 +549,38 @@ def _validate_subpop_indices( """ self._validate_index_array(subpop_indices, n_subpops, "subpop_indices") - def _validate_obs_times_shape(self, obs: ArrayLike, times: ArrayLike) -> None: + def _validate_shapes_match( + self, + first: ArrayLike, + second: ArrayLike, + first_name: str, + second_name: str, + ) -> None: """ - Validate that obs and times arrays have matching shapes. + Validate that two arrays have matching shapes. Parameters ---------- - obs - Observed data array. - times - Times index array. + first + First array. + second + Second array. + first_name + Name of the first parameter (for error messages). + second_name + Name of the second parameter (for error messages). Raises ------ ValueError - If obs and times have different shapes. + If the two arrays have different shapes. """ - obs = jnp.asarray(obs) - times = jnp.asarray(times) - if obs.shape != times.shape: + first = jnp.asarray(first) + second = jnp.asarray(second) + if first.shape != second.shape: raise ValueError( - f"Observation '{self.name}': obs shape {obs.shape} " - f"must match times shape {times.shape}" + f"Observation '{self.name}': {first_name} shape {first.shape} " + f"must match {second_name} shape {second.shape}" ) def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: @@ -571,8 +588,9 @@ def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: Validate that obs covers the full shared time axis. For dense observations on the shared time axis ``[0, n_total)``, - obs must have length equal to ``n_total``. Use NaN to mark - unobserved timepoints (initialization period or missing data). + obs must be 1D with length equal to ``n_total``. Use NaN to + mark unobserved timepoints (initialization period or missing + data). Parameters ---------- @@ -584,9 +602,13 @@ def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: Raises ------ ValueError - If obs length doesn't equal n_total. + If obs is not 1D or its length doesn't equal n_total. """ obs = jnp.asarray(obs) + if obs.ndim != 1: + raise ValueError( + f"Observation '{self.name}': obs must be 1D, got shape {obs.shape}" + ) if obs.shape[0] != n_total: raise ValueError( f"Observation '{self.name}': obs length {obs.shape[0]} " diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 9519942e..3a7ea06b 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -429,6 +429,10 @@ def validate_data( offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) n_periods = (n_total - offset) // P obs = jnp.asarray(obs) + if obs.ndim != 1: + raise ValueError( + f"Observation '{self.name}': obs must be 1D, got shape {obs.shape}" + ) if obs.shape[0] != n_periods: raise ValueError( f"Observation '{self.name}': obs length {obs.shape[0]} " @@ -447,7 +451,9 @@ def validate_data( offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) self._validate_period_end_times(period_end_times, n_total, offset, P) if obs is not None: - self._validate_obs_times_shape(obs, period_end_times) + self._validate_shapes_match( + obs, period_end_times, "obs", "period_end_times" + ) def sample( self, @@ -699,9 +705,16 @@ def validate_data( offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) self._validate_period_end_times(period_end_times, n_total, offset, P) if obs is not None: - self._validate_obs_times_shape(obs, period_end_times) + self._validate_shapes_match( + obs, period_end_times, "obs", "period_end_times" + ) if subpop_indices is not None: - self._validate_obs_times_shape(subpop_indices, period_end_times) + self._validate_shapes_match( + subpop_indices, + period_end_times, + "subpop_indices", + "period_end_times", + ) def sample( self, diff --git a/pyrenew/observation/measurement_observations.py b/pyrenew/observation/measurement_observations.py index e5c951a7..137e8ce0 100644 --- a/pyrenew/observation/measurement_observations.py +++ b/pyrenew/observation/measurement_observations.py @@ -146,7 +146,7 @@ def validate_data( if times is not None: self._validate_times(times, n_total) if obs is not None: - self._validate_obs_times_shape(obs, times) + self._validate_shapes_match(obs, times, "obs", "times") if subpop_indices is not None: self._validate_subpop_indices(subpop_indices, n_subpops) if sensor_indices is not None and n_sensors is not None: diff --git a/test/integration/conftest.py b/test/integration/conftest.py index f954fbb7..2ba38544 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -23,7 +23,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import AR1 from pyrenew.latent.population_infections import PopulationInfections -from pyrenew.model import PyrenewBuilder +from pyrenew.model import PyrenewBuilder, MultiSignalModel from pyrenew.observation import NegativeBinomialNoise, PopulationCounts from pyrenew.randomvariable import DistributionalVariable @@ -149,7 +149,7 @@ def he_model( hosp_delay_pmf: jnp.ndarray, ed_delay_pmf: jnp.ndarray, ed_day_of_week_effects: jnp.ndarray, -) -> PyrenewBuilder: +) -> MultiSignalModel: """ Build a PopulationInfections model with hospital + ED observation processes. @@ -209,7 +209,7 @@ def he_weekly_model( hosp_delay_pmf: jnp.ndarray, ed_delay_pmf: jnp.ndarray, ed_day_of_week_effects: jnp.ndarray, -) -> PyrenewBuilder: +) -> MultiSignalModel: """ Build a PopulationInfections model with WEEKLY hospital + DAILY ED observations. diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py index 16ca3db4..cbd47906 100644 --- a/test/test_observation_validation.py +++ b/test/test_observation_validation.py @@ -198,6 +198,24 @@ def test_empty_array_passes(self, counts_proc): param_name="test_idx", ) + def test_scalar_raises(self, counts_proc): + """0-D scalar index array should raise ValueError.""" + with pytest.raises(ValueError, match="must be 1D"): + counts_proc._validate_index_array( + jnp.asarray(0), + upper_bound=5, + param_name="test_idx", + ) + + def test_2d_raises(self, counts_proc): + """2D index array should raise ValueError.""" + with pytest.raises(ValueError, match="must be 1D"): + counts_proc._validate_index_array( + jnp.array([[0, 1], [2, 3]]), + upper_bound=5, + param_name="test_idx", + ) + # =================================================================== # _validate_times @@ -260,44 +278,56 @@ def test_subpop_index_above_n_subpops_raises(self, counts_proc): # =================================================================== -# _validate_obs_times_shape +# _validate_shapes_match # =================================================================== -class TestValidateObsTimesShape: - """Tests for BaseObservationProcess._validate_obs_times_shape.""" +class TestValidateShapesMatch: + """Tests for BaseObservationProcess._validate_shapes_match.""" def test_matching_1d_shapes(self, counts_proc): - """Obs and times with matching 1D shapes should not raise.""" - obs = jnp.array([1.0, 2.0, 3.0]) - times = jnp.array([0, 5, 10]) - counts_proc._validate_obs_times_shape(obs, times) + """Two arrays with matching 1D shapes should not raise.""" + a = jnp.array([1.0, 2.0, 3.0]) + b = jnp.array([0, 5, 10]) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_mismatched_lengths_raises(self, counts_proc): - """Obs and times with different lengths should raise ValueError.""" - obs = jnp.array([1.0, 2.0, 3.0]) - times = jnp.array([0, 5]) + """Two arrays with different lengths should raise ValueError.""" + a = jnp.array([1.0, 2.0, 3.0]) + b = jnp.array([0, 5]) with pytest.raises(ValueError, match="must match times shape"): - counts_proc._validate_obs_times_shape(obs, times) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_empty_arrays_match(self, counts_proc): """Two empty arrays should not raise.""" - obs = jnp.array([]) - times = jnp.array([]) - counts_proc._validate_obs_times_shape(obs, times) + a = jnp.array([]) + b = jnp.array([]) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_scalar_arrays_match(self, counts_proc): """Single-element arrays should not raise.""" - obs = jnp.array([5.0]) - times = jnp.array([0]) - counts_proc._validate_obs_times_shape(obs, times) + a = jnp.array([5.0]) + b = jnp.array([0]) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_error_includes_both_shapes(self, counts_proc): """Error message should report both shapes.""" - obs = jnp.array([1.0, 2.0]) - times = jnp.array([0, 1, 2]) + a = jnp.array([1.0, 2.0]) + b = jnp.array([0, 1, 2]) with pytest.raises(ValueError, match=r"\(2,\).*\(3,\)"): - counts_proc._validate_obs_times_shape(obs, times) + counts_proc._validate_shapes_match(a, b, "obs", "times") + + def test_error_uses_provided_names(self, counts_proc): + """Error message should use the caller-supplied parameter names.""" + a = jnp.array([1.0, 2.0]) + b = jnp.array([0, 1, 2]) + with pytest.raises( + ValueError, + match=r"subpop_indices shape .* must match period_end_times shape", + ): + counts_proc._validate_shapes_match( + a, b, "subpop_indices", "period_end_times" + ) # =================================================================== @@ -336,6 +366,12 @@ def test_error_includes_lengths(self, counts_proc): with pytest.raises(ValueError, match="15.*30"): counts_proc._validate_obs_dense(obs, n_total=30) + def test_2d_obs_raises(self, counts_proc): + """2D obs (e.g., shape (n_total, 1)) should raise ValueError.""" + obs = jnp.ones((30, 1)) + with pytest.raises(ValueError, match="obs must be 1D"): + counts_proc._validate_obs_dense(obs, n_total=30) + # =================================================================== # PopulationCounts.validate_data() @@ -365,6 +401,36 @@ def test_wrong_length_obs_raises(self, counts_proc): with pytest.raises(ValueError, match="must equal n_total"): counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + def test_2d_obs_daily_raises(self, counts_proc): + """Daily regular-schedule obs must be 1D; 2D (n_total, 1) should raise.""" + obs = jnp.ones((30, 1)) + with pytest.raises(ValueError, match="obs must be 1D"): + counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + + def test_2d_obs_weekly_raises(self): + """Weekly regular-schedule obs must be 1D; 2D (n_periods, 1) should raise.""" + proc = PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.3, 0.5, 0.2])), + noise=PoissonNoise(), + aggregation_period=7, + reporting_schedule="regular", + period_end_dow=5, + ) + n_total = 35 + first_day_dow = 6 + offset = (proc.period_end_dow + 1 - first_day_dow) % proc.aggregation_period + n_periods = (n_total - offset) // proc.aggregation_period + obs = jnp.ones((n_periods, 1)) + with pytest.raises(ValueError, match="obs must be 1D"): + proc.validate_data( + n_total=n_total, + n_subpops=1, + obs=obs, + first_day_dow=first_day_dow, + ) + def test_extra_kwargs_ignored(self, counts_proc): """validate_data should ignore extra keyword arguments.""" counts_proc.validate_data( @@ -433,7 +499,9 @@ def test_mismatched_obs_times_raises(self, subpop_proc): """validate_data with obs/period_end_times shape mismatch should raise.""" period_end_times = jnp.array([5, 10, 15]) obs = jnp.array([1.0, 2.0]) # length 2 != length 3 - with pytest.raises(ValueError, match="must match times shape"): + with pytest.raises( + ValueError, match=r"obs shape .* must match period_end_times shape" + ): subpop_proc.validate_data( n_total=30, n_subpops=3, @@ -460,6 +528,13 @@ def test_non_contiguous_subpop_indices_valid(self, subpop_proc): n_total=30, n_subpops=3, subpop_indices=subpop_indices ) + def test_scalar_subpop_indices_raises(self, subpop_proc): + """Scalar (0-D) subpop_indices should raise ValueError, not IndexError.""" + with pytest.raises(ValueError, match="must be 1D"): + subpop_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=jnp.asarray(0) + ) + # =================================================================== # MeasurementObservation.validate_data() @@ -555,7 +630,7 @@ def test_mismatched_obs_times_raises(self, measurements_proc): """validate_data with obs/times shape mismatch should raise.""" times = jnp.array([5, 10, 15]) obs = jnp.array([1.0, 2.0]) - with pytest.raises(ValueError, match="must match times shape"): + with pytest.raises(ValueError, match=r"obs shape .* must match times shape"): measurements_proc.validate_data( n_total=30, n_subpops=3, times=times, obs=obs ) From b23be720faf14cbb51fa03a1b767a7a9d5283874 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 22 Apr 2026 11:24:05 -0400 Subject: [PATCH 07/17] changes per co-pilot code review --- test/integration/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 2ba38544..68337b9c 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -23,7 +23,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import AR1 from pyrenew.latent.population_infections import PopulationInfections -from pyrenew.model import PyrenewBuilder, MultiSignalModel +from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import NegativeBinomialNoise, PopulationCounts from pyrenew.randomvariable import DistributionalVariable From c715e462c5debdb68c34e668ee979b0004faf815 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 22 Apr 2026 12:09:30 -0400 Subject: [PATCH 08/17] more unit tests --- test/test_observation_counts.py | 41 +++++++++++++++++++++++++++++++++ test/test_pyrenew_builder.py | 11 +++++++++ test/test_temporal_processes.py | 9 ++++++++ 3 files changed, 61 insertions(+) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index e072c9b4..ad1cf940 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -1036,6 +1036,12 @@ def test_daily_irregular_out_of_bounds_raises(self, daily_irregular_counts): n_total=20, n_subpops=1, period_end_times=period_end_times ) + def test_irregular_no_period_end_times_is_noop(self, weekly_irregular_counts): + """Irregular schedule with period_end_times=None returns without error.""" + weekly_irregular_counts.validate_data( + n_total=28, n_subpops=1, obs=None, period_end_times=None + ) + # =================================================================== # PopulationCounts with aggregation: sample @@ -1127,6 +1133,14 @@ def test_daily_irregular_period_indexing(self, daily_irregular_counts): assert result.predicted.shape == (30,) assert result.observed.shape == (3,) + def test_aggregate_helper_missing_first_day_dow_raises(self, weekly_regular_counts): + """_aggregate raises when aggregation_period > 1 and first_day_dow is None.""" + predicted_daily = jnp.ones(28) + with pytest.raises( + ValueError, match="first_day_dow is required when aggregation_period > 1" + ): + weekly_regular_counts._aggregate(predicted_daily, first_day_dow=None) + # =================================================================== # SubpopulationCounts with aggregation: validate_data @@ -1162,6 +1176,10 @@ def test_weekly_regular_valid_passes(self, weekly_regular_subpop_counts): subpop_indices=subpop_indices, ) + def test_regular_no_obs_is_noop(self, weekly_regular_subpop_counts): + """Regular schedule with obs=None returns without error.""" + weekly_regular_subpop_counts.validate_data(n_total=28, n_subpops=3, obs=None) + def test_weekly_regular_wrong_n_periods_raises(self, weekly_regular_subpop_counts): """Weekly-regular obs with wrong dim-0 length raises.""" obs = jnp.ones((28, 2)) * 5.0 @@ -1435,6 +1453,29 @@ def test_daily_irregular_fancy_indexing( assert result.predicted.shape == (30, 3) assert result.observed.shape == (3,) + def test_weekly_regular_with_obs_conditions( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """Weekly-regular sample conditions on 2D obs with NaN-padding for unobserved periods.""" + subpop_indices = jnp.array([0, 2]) + obs = jnp.array( + [ + [7.0, 7.0], + [jnp.nan, jnp.nan], + [7.0, 7.0], + [7.0, 7.0], + ] + ) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (4, 3) + assert result.observed.shape == (4, 2) + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 19497048..a0e094b8 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -581,6 +581,17 @@ def test_matching_weekly_period_end_dow_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) + def test_step_size_mismatches_coarse_period_raises(self): + """step_size > 1 must equal every coarse observation aggregation_period.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=2 + ), + observations=[_weekly_hosp_counts()], + ) + with pytest.raises(ValueError, match="All coarse cadences must agree"): + builder.build() + class TestMultiSignalValidateDataAnchor: """MultiSignalModel.validate_data sample-time anchor check for first_day_dow.""" diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index 48ebd2df..34b141c2 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -279,6 +279,15 @@ def test_float_step_size_raises(self): with pytest.raises(ValueError, match="positive integer"): StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=7.0) + def test_repr_includes_inner_and_step_size(self): + """__repr__ shows the inner process and step_size.""" + inner = AR1(autoreg=0.9, innovation_sd=0.05) + wrapper = StepwiseTemporalProcess(inner, step_size=7) + rendered = repr(wrapper) + assert rendered.startswith("StepwiseTemporalProcess(") + assert f"inner={inner!r}" in rendered + assert "step_size=7" in rendered + class TestStepwiseTemporalProcessSample: """Sample-time behavior of StepwiseTemporalProcess.""" From 21498c5b5b589c3c1564b8f42565c2e386ae339c Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 22 Apr 2026 19:28:47 -0400 Subject: [PATCH 09/17] builder allows weekly or daily temporal processes --- .../tutorials/building_multisignal_models.qmd | 49 +++++ pyrenew/arrayutils.py | 22 ++ pyrenew/latent/population_infections.py | 15 +- pyrenew/latent/subpopulation_infections.py | 22 +- pyrenew/latent/temporal_processes.py | 138 +++++++++++-- pyrenew/model/multisignal_model.py | 5 + pyrenew/model/pyrenew_builder.py | 53 +++-- pyrenew/observation/base.py | 6 +- pyrenew/observation/count_observations.py | 50 ++++- test/conftest.py | 120 +++++++++++ test/test_population_infections.py | 45 ++++- test/test_pyrenew_builder.py | 190 +++++++++++++++++- test/test_subpopulation_infections.py | 44 ++++ test/test_temporal_processes.py | 104 ++++++++++ 14 files changed, 793 insertions(+), 70 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 23857e3a..5b97ea92 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -67,6 +67,7 @@ from pyrenew.latent import ( SubpopulationInfections, AR1, RandomWalk, + StepwiseTemporalProcess, GammaGroupSdPrior, HierarchicalNormalPrior, ) @@ -237,6 +238,54 @@ baseline_rt_process = AR1(autoreg=0.9, innovation_sd=0.05) subpop_rt_deviation_process = RandomWalk(innovation_sd=0.025) ``` +### Choosing the Rt Parameter Cadence + +The renewal equation is evaluated on the model's daily time axis, but the temporal process for $\mathcal{R}(t)$ does not have to sample a new parameter every day. +This separates three model choices: + +- **Parameter cadence**: how often the $\mathcal{R}(t)$ temporal process samples a new latent value +- **Model time axis**: the daily axis used by the renewal equation and delay convolutions +- **Observation cadence**: the scale of the likelihood for each signal, such as daily ED visits or weekly hospital admissions + +The AR(1) process above samples one value per model day: + +```python +baseline_rt_process = AR1(autoreg=0.9, innovation_sd=0.05) +``` + +To parameterize $\mathcal{R}(t)$ weekly while still running the renewal equation daily, wrap the temporal process in `StepwiseTemporalProcess`. +The wrapper samples a coarse trajectory and broadcasts it to the daily model axis before the latent infection process uses it. + +```python +weekly_baseline_rt_process = StepwiseTemporalProcess( + inner=AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, # Sunday +) +``` + +Use `alignment="calendar_week"` when the weekly Rt blocks should align to a known weekday. +For MMWR-style weeks ending on Saturday, use `week_start_dow=6` for Sunday starts; the corresponding weekly count observation would use `period_end_dow=5`. +At sample or run time, pass the day of week for element 0 of the shared model axis: + +```python +first_day_dow = model.compute_first_day_dow(obs_start_dow) + +model.run( + rng_key, + n_warmup=500, + n_samples=1000, + first_day_dow=first_day_dow, + hospital={"obs": hospital_obs, "first_day_dow": first_day_dow}, + ed={"obs": ed_obs}, +) +``` + +The top-level `first_day_dow` is used by calendar-aligned latent temporal processes. +The observation-level `first_day_dow` is used by observation processes that need calendar alignment, such as weekly aggregation or day-of-week effects. +For `alignment="model_index"`, stepwise blocks start at model index 0 and no top-level `first_day_dow` is needed. + ## Observation Processes Observation processes transform latent infections into observable signals and define the statistical model linking predictions to data. Each observation process: diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 64750d42..973466a4 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -24,6 +24,28 @@ def __repr__(self) -> str: return f"PeriodicProcessSample(value={self})" +def require_shape(arr: ArrayLike, expected: tuple[int, ...], label: str) -> None: + """ + Validate that ``arr.shape`` equals ``expected``. + + Parameters + ---------- + arr + Array whose shape is being checked. + expected + Required shape. + label + Name of the producing component, used in the error message. + + Raises + ------ + ValueError + If ``arr.shape`` does not equal ``expected``. + """ + if arr.shape != expected: + raise ValueError(f"{label} must return shape {expected}; got {arr.shape}") + + def tile_until_n( data: ArrayLike, n_timepoints: int, diff --git a/pyrenew/latent/population_infections.py b/pyrenew/latent/population_infections.py index 1d1ce1cf..e6feb7c3 100644 --- a/pyrenew/latent/population_infections.py +++ b/pyrenew/latent/population_infections.py @@ -9,6 +9,7 @@ from jax.typing import ArrayLike from numpyro.util import not_jax_tracer +from pyrenew.arrayutils import require_shape from pyrenew.deterministic import DeterministicVariable from pyrenew.distutil import validate_discrete_dist_vector from pyrenew.latent.base import ( @@ -147,13 +148,18 @@ def sample( self, n_days_post_init: int, subpop_fractions: ArrayLike | None = None, + first_day_dow: int | None = None, **kwargs: object, ) -> LatentSample: """ Sample population infections using a single renewal process. - Generates a single $\\mathcal{R}(t)$ trajectory, computes initial infections via - exponential backprojection, and runs one renewal equation. + Generates a single daily $\\mathcal{R}(t)$ trajectory, computes initial + infections via exponential backprojection, and runs one deterministic + daily renewal equation. The temporal process may sample parameters at + any supported cadence (for example daily or weekly/stepwise), but it + must return a daily-length trajectory before the renewal equation is + evaluated. Parameters ---------- @@ -162,6 +168,9 @@ def sample( subpop_fractions Population fractions. Defaults to ``[1.0]`` (single population). Must be ``[1.0]`` if provided. + first_day_dow + Forwarded to ``single_rt_process``. See + [pyrenew.latent.TemporalProcess][]. **kwargs Additional arguments (unused, for compatibility) @@ -196,7 +205,9 @@ def sample( n_timepoints=n_total_days, initial_value=initial_log_rt, name_prefix="log_rt_single", + first_day_dow=first_day_dow, ) + require_shape(log_rt_single, (n_total_days, 1), "single_rt_process") rt_single = jnp.exp(log_rt_single) diff --git a/pyrenew/latent/subpopulation_infections.py b/pyrenew/latent/subpopulation_infections.py index 613f75f8..453365e3 100644 --- a/pyrenew/latent/subpopulation_infections.py +++ b/pyrenew/latent/subpopulation_infections.py @@ -11,6 +11,7 @@ import numpyro from jax.typing import ArrayLike +from pyrenew.arrayutils import require_shape from pyrenew.deterministic import DeterministicVariable from pyrenew.distutil import validate_discrete_dist_vector from pyrenew.latent.base import ( @@ -162,13 +163,18 @@ def sample( self, n_days_post_init: int, subpop_fractions: ArrayLike | None = None, + first_day_dow: int | None = None, **kwargs: object, ) -> LatentSample: """ Sample hierarchical infections for all subpopulations. - Generates baseline $\\mathcal{R}(t)$, subpopulation deviations with sum-to-zero - constraint, initial infections, and runs n_subpops independent renewal processes. + Generates daily baseline $\\mathcal{R}(t)$, daily subpopulation + deviations with a sum-to-zero constraint, initial infections, and runs + n_subpops deterministic daily renewal processes. The temporal + processes may sample parameters at any supported cadence (for example + daily or weekly/stepwise), but they must return daily-length + trajectories before the renewal equations are evaluated. Parameters ---------- @@ -177,6 +183,10 @@ def sample( subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). Must sum to 1.0. + first_day_dow + Forwarded to ``baseline_rt_process`` and + ``subpop_rt_deviation_process``. See + [pyrenew.latent.TemporalProcess][]. **kwargs Additional arguments (unused, for compatibility) @@ -199,13 +209,21 @@ def sample( n_timepoints=n_total_days, initial_value=initial_log_rt, name_prefix="log_rt_baseline", + first_day_dow=first_day_dow, ) + require_shape(log_rt_baseline, (n_total_days, 1), "baseline_rt_process") deviations_raw = self.subpop_rt_deviation_process.sample( n_timepoints=n_total_days, n_processes=pop.n_subpops, initial_value=jnp.zeros(pop.n_subpops), name_prefix="subpop_deviations", + first_day_dow=first_day_dow, + ) + require_shape( + deviations_raw, + (n_total_days, pop.n_subpops), + "subpop_rt_deviation_process", ) # Sum-to-zero constraint ensures identifiability diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 5fef2101..98b55d81 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -50,7 +50,7 @@ from __future__ import annotations -from typing import Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable import jax.numpy as jnp import numpyro @@ -60,6 +60,7 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.process.randomwalk import RandomWalk as ProcessRandomWalk from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import validate_dow, weekly_to_daily @runtime_checkable @@ -89,6 +90,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "temporal", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample temporal trajectory or trajectories. @@ -105,6 +108,10 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample site names to avoid collisions + first_day_dow + Day of week for element 0 of the shared model time axis + (0=Monday, ..., 6=Sunday). Standard temporal processes ignore + this value; calendar-aligned wrappers may use it. Returns ------- @@ -170,6 +177,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "ar1", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample AR(1) trajectory or trajectories. @@ -184,6 +193,8 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites + first_day_dow + Unused. See [pyrenew.latent.TemporalProcess][]. Returns ------- @@ -277,6 +288,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "diff_ar1", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample differenced AR(1) trajectory or trajectories. @@ -291,6 +304,8 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites + first_day_dow + Unused. See [pyrenew.latent.TemporalProcess][]. Returns ------- @@ -380,6 +395,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "rw", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample random walk trajectory or trajectories. @@ -394,6 +411,8 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites + first_day_dow + Unused. See [pyrenew.latent.TemporalProcess][]. Returns ------- @@ -427,10 +446,11 @@ class StepwiseTemporalProcess(TemporalProcess): Parameterize an inner temporal process at a coarser cadence and broadcast to the per-timepoint scale by repetition. - Each ``step_size`` consecutive output timepoints share a single - sampled value from the inner process. Use to match R(t) - parametrization cadence to the coarsest observation cadence - (e.g., ``step_size=7`` with weekly-aggregated observations). + Each ``step_size`` consecutive output timepoints share a single sampled + value from the inner process. Use when a parameter should be estimated at + a coarser cadence while downstream model components still need one value + per evaluation timepoint. For example, a weekly-parameterized R(t) process + can return daily R(t) values for a daily deterministic renewal equation. Parameters ---------- @@ -440,14 +460,32 @@ class StepwiseTemporalProcess(TemporalProcess): step_size Number of per-timepoint units that share each inner sample. Must be a positive integer. + alignment + How repeated blocks align to the output time axis. ``"model_index"`` + starts blocks at output index 0. ``"calendar_week"`` aligns blocks + to calendar weeks using ``week_start_dow`` and the ``first_day_dow`` + supplied to ``sample()``. + week_start_dow + Day of week on which weekly blocks begin when + ``alignment="calendar_week"`` (0=Monday, ..., 6=Sunday). Required + for calendar-week alignment. Raises ------ ValueError - If ``step_size`` is not a positive integer. + If ``step_size`` is not a positive integer, or if alignment + arguments are inconsistent. """ - def __init__(self, inner: TemporalProcess, step_size: int) -> None: + _SUPPORTED_ALIGNMENTS = {"model_index", "calendar_week"} + + def __init__( + self, + inner: TemporalProcess, + step_size: int, + alignment: Literal["model_index", "calendar_week"] = "model_index", + week_start_dow: int | None = None, + ) -> None: """ Initialize stepwise temporal process. @@ -457,37 +495,96 @@ def __init__(self, inner: TemporalProcess, step_size: int) -> None: Inner ``TemporalProcess`` that generates the coarse trajectory. step_size Number of per-timepoint units that share each inner sample. + alignment + How repeated blocks align to the output time axis. ``"model_index"`` + starts blocks at output index 0. ``"calendar_week"`` aligns + weekly blocks to ``week_start_dow`` using ``first_day_dow`` at + sample time. + week_start_dow + Day of week on which weekly blocks begin when + ``alignment="calendar_week"`` (0=Monday, ..., 6=Sunday). Raises ------ ValueError - If ``step_size`` is not a positive integer. + If ``step_size`` is not a positive integer, or if alignment + arguments are inconsistent. """ if not isinstance(step_size, int) or step_size < 1: raise ValueError(f"step_size must be a positive integer, got {step_size!r}") + if alignment not in self._SUPPORTED_ALIGNMENTS: + raise ValueError( + f"alignment must be one of {self._SUPPORTED_ALIGNMENTS}, " + f"got {alignment!r}" + ) + if alignment == "calendar_week": + if step_size != 7: + raise ValueError( + "calendar_week alignment requires step_size=7, " + f"got step_size={step_size}" + ) + if week_start_dow is None: + raise ValueError( + "week_start_dow is required when alignment='calendar_week'" + ) + validate_dow(week_start_dow, "week_start_dow") + elif week_start_dow is not None: + raise ValueError( + "week_start_dow is only used when alignment='calendar_week'" + ) self.inner = inner self.step_size = step_size + self.alignment = alignment + self.week_start_dow = week_start_dow def __repr__(self) -> str: """Return string representation.""" return ( - f"StepwiseTemporalProcess(inner={self.inner!r}, step_size={self.step_size})" + f"StepwiseTemporalProcess(inner={self.inner!r}, " + f"step_size={self.step_size}, alignment={self.alignment!r}, " + f"week_start_dow={self.week_start_dow!r})" ) + def _resolve_n_coarse(self, n_timepoints: int, first_day_dow: int | None) -> int: + """ + Return the number of inner-process samples needed. + + Returns + ------- + int + Number of coarse samples required to cover ``n_timepoints`` under + the configured alignment. + """ + if self.alignment == "model_index": + return (n_timepoints + self.step_size - 1) // self.step_size + if first_day_dow is None: + raise ValueError( + "first_day_dow is required at sample time when " + "alignment='calendar_week'" + ) + validate_dow(first_day_dow, "first_day_dow") + trim = (first_day_dow - self.week_start_dow) % 7 + return (n_timepoints + trim + 6) // 7 + def sample( self, n_timepoints: int, initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "stepwise", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample coarse trajectory from inner process and broadcast. - Computes ``n_steps = ceil(n_timepoints / step_size)``, samples - the inner process at that cadence, then repeats each coarse - value ``step_size`` times along the time axis and trims to - ``n_timepoints``. + Computes the number of coarse time steps needed for the requested + alignment, samples the inner process at that cadence, then broadcasts + each coarse value to the per-timepoint axis and trims to + ``n_timepoints``. The returned value always has one row per evaluation + timepoint, regardless of the inner parameter cadence. The coarse + trajectory is recorded as a NumPyro deterministic site named + ``"{name_prefix}_coarse"``. Parameters ---------- @@ -499,6 +596,10 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites; forwarded to the inner process. + first_day_dow + Day of week for element 0 of the shared model time axis + (0=Monday, ..., 6=Sunday). Required when + ``alignment="calendar_week"``. Returns ------- @@ -506,11 +607,18 @@ def sample( Trajectories of shape ``(n_timepoints, n_processes)``, constant within each block of ``step_size`` consecutive rows. """ - n_steps = (n_timepoints + self.step_size - 1) // self.step_size + n_steps = self._resolve_n_coarse(n_timepoints, first_day_dow) coarse = self.inner.sample( n_timepoints=n_steps, initial_value=initial_value, n_processes=n_processes, name_prefix=name_prefix, ) - return jnp.repeat(coarse, repeats=self.step_size, axis=0)[:n_timepoints] + numpyro.deterministic(f"{name_prefix}_coarse", coarse) + if self.alignment == "model_index": + return jnp.repeat(coarse, repeats=self.step_size, axis=0)[:n_timepoints] + return weekly_to_daily( + coarse, + week_start_dow=self.week_start_dow, + output_data_first_dow=first_day_dow, + )[:n_timepoints] diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 768eee25..616471c7 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -247,6 +247,7 @@ def sample( population_size: float, *, subpop_fractions: ArrayLike | None = None, + first_day_dow: int | None = None, **observation_data: dict[str, object], ) -> None: """ @@ -263,6 +264,9 @@ def sample( (from latent process) to infection counts (for observation processes). subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). + first_day_dow + Forwarded to the latent process. See + [pyrenew.latent.TemporalProcess][]. **observation_data Data for each observation process, keyed by observation name (the ``name`` attribute of each observation process). @@ -280,6 +284,7 @@ def sample( latent_sample = self.latent.sample( n_days_post_init=n_days_post_init, subpop_fractions=subpop_fractions, + first_day_dow=first_day_dow, ) # Scale from proportions to counts diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index 1258d932..c6447d49 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -198,34 +198,27 @@ def compute_n_initialization_points(self) -> int: def _validate_coherence(self) -> None: """ - Enforce end-to-end coherence between R(t) cadence and observation cadences. - - Called at the start of ``build()`` before any model components are - constructed. Inspects ``self.observations`` for ``aggregation_period`` - and ``period_end_dow`` attributes, and walks ``self.latent_params`` - values for any ``TemporalProcess`` instances to read their - ``step_size`` attribute. + Enforce calendar-anchor and structural coherence across components. Checks: - All observations sharing the same ``aggregation_period > 1`` must agree on ``period_end_dow``. - - Every temporal-process ``step_size`` must be ``<=`` the finest - observation ``aggregation_period``. - - If a temporal process has ``step_size > 1``, that ``step_size`` - must equal the ``aggregation_period`` of every observation whose - ``aggregation_period > 1``. + - Every temporal-process ``step_size`` must be a positive integer. + - Calendar-week-aligned temporal processes must have a + ``week_start_dow`` consistent with the ``period_end_dow`` of any + weekly observation: ``period_end_dow == (week_start_dow + 6) % 7``. Raises ------ ValueError - If any of the three rules is violated. + If any of the above checks fails. """ + # Intentionally permissive: parameter cadence and observation cadence + # are independent. agg_by_dow: dict[int, set[int]] = {} - agg_periods: list[int] = [] for name, obs in self.observations.items(): P = getattr(obs, "aggregation_period", 1) - agg_periods.append(P) if P > 1: agg_by_dow.setdefault(P, set()).add(getattr(obs, "period_end_dow")) @@ -242,26 +235,28 @@ def _validate_coherence(self) -> None: if isinstance(value, TemporalProcess) } - finest = min(agg_periods) if agg_periods else 1 - coarse_periods = {P for P in agg_periods if P > 1} + weekly_period_end_dow = next(iter(agg_by_dow.get(7, set())), None) for param_name, process in temporal_processes.items(): step_size = getattr(process, "step_size", 1) - if step_size > finest: + if not isinstance(step_size, int) or step_size < 1: raise ValueError( - f"Temporal process '{param_name}' has step_size={step_size} " - f"exceeding the finest observation aggregation_period " - f"({finest}). Parameterize R(t) at least as finely as the " - f"finest observed signal." + f"Temporal process '{param_name}' must expose a positive " + f"integer step_size; got {step_size!r}" ) - if step_size > 1: - mismatched = {P for P in coarse_periods if P != step_size} - if mismatched: + if ( + getattr(process, "alignment", None) == "calendar_week" + and weekly_period_end_dow is not None + ): + week_start_dow = getattr(process, "week_start_dow", None) + expected_period_end_dow = (week_start_dow + 6) % 7 + if expected_period_end_dow != weekly_period_end_dow: raise ValueError( - f"Temporal process '{param_name}' has step_size=" - f"{step_size} but observations include " - f"aggregation_period(s) {sorted(mismatched)} that do " - f"not match. All coarse cadences must agree." + f"Temporal process '{param_name}' has " + f"week_start_dow={week_start_dow}, which implies " + f"weekly observations should end on " + f"period_end_dow={expected_period_end_dow}; got " + f"period_end_dow={weekly_period_end_dow}" ) def build(self) -> MultiSignalModel: diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index c7bc6e3b..136ce89b 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -14,6 +14,7 @@ import numpyro from jax.typing import ArrayLike +from pyrenew.arrayutils import require_shape from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.metaclass import RandomVariable @@ -191,10 +192,7 @@ def _validate_dow_effect( ValueError If shape is not (7,) or any values are negative. """ - if dow_effect.shape != (7,): - raise ValueError( - f"{param_name} must return shape (7,), got {dow_effect.shape}" - ) + require_shape(dow_effect, (7,), param_name) if jnp.any(dow_effect < 0): raise ValueError(f"{param_name} must have non-negative values") diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 3a7ea06b..8b70f217 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -22,7 +22,10 @@ class CountObservation(BaseObservationProcess): Abstract Base class for count observation processes. Subclasses map infections to counts through ascertainment x delay convolution - with composable noise model. + with composable noise model. Count observations always receive predictions + on the model's daily time axis and then, if requested, aggregate those + daily predictions to the observation reporting grid before evaluating the + likelihood. """ _SUPPORTED_SCHEDULES = ("regular", "irregular") @@ -69,9 +72,11 @@ def __init__( ``first_day_dow`` at sample time), predicted counts are scaled by a periodic weekly pattern. aggregation_period - Width of the reporting period in fundamental time units. - Must be in ``{1, 7}``. ``1`` means no aggregation (daily - observations). + Width of the observation reporting period in days. Must be in + ``{1, 7}``. ``1`` means no aggregation (daily observations). + This controls only the scale on which the count likelihood is + evaluated. It does not control how often the latent Rt temporal + process samples new parameters. reporting_schedule Either ``"regular"`` (dense observation array, one entry per period, NaN for unobserved periods) or ``"irregular"`` @@ -289,13 +294,15 @@ def _aggregate( first_day_dow: int | None, ) -> ArrayLike: """ - Aggregate daily predicted counts to the reporting-period grid. + Aggregate daily predicted counts to the observation reporting grid. When ``aggregation_period == 1`` returns the input unchanged. Otherwise sums daily values over non-overlapping fixed-width periods anchored by ``period_end_dow``, via ``pyrenew.time.daily_to_weekly``. Works on both 1D ``(n_total,)`` and 2D ``(n_total, n_subpops)`` inputs. + This aggregation is part of the observation likelihood path and is + independent of the parameter cadence used by the latent Rt process. Parameters ---------- @@ -335,7 +342,10 @@ class PopulationCounts(CountObservation): Aggregated count observation. Maps aggregate infections to counts through ascertainment x delay - convolution with composable noise model. + convolution with composable noise model. Predictions are constructed on + the daily model axis; ``aggregation_period`` controls whether those + predictions are scored as daily counts or summed to weekly counts before + the likelihood. Parameters ---------- @@ -402,7 +412,8 @@ def validate_data( first_day_dow Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period > 1``. + ``aggregation_period > 1`` so weekly observation periods can be + aligned to the shared daily model axis. **kwargs Additional keyword arguments (ignored). @@ -475,6 +486,11 @@ def sample( aggregated array at period indices derived from ``period_end_times``. + ``aggregation_period`` describes the observation scale only. The + latent infection process may use daily or coarser Rt parameter + cadence, but by the time this method is called it supplies infections + on the daily model axis. + Parameters ---------- infections @@ -494,7 +510,9 @@ def sample( Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when ``day_of_week_rv`` was set at construction or when - ``aggregation_period > 1``. + ``aggregation_period > 1``. This aligns observation-level + day-of-week effects or weekly aggregation to the shared daily + model axis. period_end_times Daily-axis indices of each observed period's final day. Required when ``reporting_schedule == "irregular"``. @@ -573,6 +591,10 @@ class SubpopulationCounts(CountObservation): Maps subpopulation-level infections to counts through ascertainment x delay convolution with composable noise model. + Predictions are constructed on the daily model axis for each + subpopulation; ``aggregation_period`` controls whether those predictions + are scored as daily counts or summed to weekly counts before the + likelihood. Parameters ---------- @@ -640,7 +662,8 @@ def validate_data( first_day_dow Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period > 1``. + ``aggregation_period > 1`` so weekly observation periods can be + aligned to the shared daily model axis. subpop_indices Subpopulation indices (0-indexed). For ``reporting_schedule="regular"``: shape @@ -737,6 +760,11 @@ def sample( a mask; ``"irregular"`` fancy-indexes the aggregated array at period indices derived from ``period_end_times``. + ``aggregation_period`` describes the observation scale only. The + latent infection process may use daily or coarser Rt parameter + cadence, but by the time this method is called it supplies infections + on the daily model axis. + Parameters ---------- infections @@ -757,7 +785,9 @@ def sample( Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when ``day_of_week_rv`` was set at construction or when - ``aggregation_period > 1``. + ``aggregation_period > 1``. This aligns observation-level + day-of-week effects or weekly aggregation to the shared daily + model axis. period_end_times Daily-axis indices of each observed period's final day. Required when ``reporting_schedule == "irregular"``. diff --git a/test/conftest.py b/test/conftest.py index bf282e2f..d41bf2c9 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -459,3 +459,123 @@ def mmwr_saturday_indices_first_three(): Shape ``(3,)`` containing ``[6, 13, 20]``. """ return jnp.array([6, 13, 20]) + + +# ============================================================================= +# Temporal Process Stubs +# ============================================================================= + + +class WrongShapeTemporalProcess: + """Temporal process stub that returns a fixed wrong-shaped array.""" + + step_size = 1 + + def __init__(self, value): + """ + Store the wrong-shaped value to return from ``sample``. + + Parameters + ---------- + value + Array returned by ``sample`` regardless of requested shape. + """ + self.value = value + + def sample(self, **kwargs): + """ + Return the configured wrong-shaped value. + + Returns + ------- + ArrayLike + The array passed to ``__init__``. + """ + return self.value + + +class ConstantTemporalProcess: + """Temporal process stub that returns zeros with the requested shape.""" + + step_size = 1 + + def sample(self, n_timepoints, n_processes=1, **kwargs): + """ + Return a correctly shaped zero trajectory. + + Parameters + ---------- + n_timepoints + Number of time points. + n_processes + Number of parallel processes. + + Returns + ------- + jnp.ndarray + Zeros of shape ``(n_timepoints, n_processes)``. + """ + return jnp.zeros((n_timepoints, n_processes)) + + +class InvalidStepSizeTemporalProcess: + """Temporal process stub with invalid builder-inspected metadata.""" + + step_size = 0 + + def sample(self, **kwargs): + """ + Return an arbitrary array. + + Builder validation should reject this process before ``sample`` runs, + so the returned value is irrelevant. + + Returns + ------- + jnp.ndarray + Shape ``(1, 1)`` array of zeros. + """ + return jnp.zeros((1, 1)) + + +@pytest.fixture +def wrong_shape_temporal_process_cls(): + """ + Class producing a temporal-process stub that returns a fixed wrong shape. + + Returns + ------- + type + The ``WrongShapeTemporalProcess`` class. Instantiate with the array + the stub should return from ``sample``. + """ + return WrongShapeTemporalProcess + + +@pytest.fixture +def constant_temporal_process(): + """ + Temporal-process stub that returns a correctly shaped zero trajectory. + + Returns + ------- + ConstantTemporalProcess + Instance whose ``sample`` returns zeros of the requested shape. + """ + return ConstantTemporalProcess() + + +@pytest.fixture +def invalid_step_size_temporal_process(): + """ + Temporal-process stub that advertises ``step_size=0``. + + Used to exercise ``PyrenewBuilder._validate_coherence`` rejection of + temporal processes whose metadata is structurally invalid. + + Returns + ------- + InvalidStepSizeTemporalProcess + Instance with ``step_size=0``. + """ + return InvalidStepSizeTemporalProcess() diff --git a/test/test_population_infections.py b/test/test_population_infections.py index 261ac8f7..2c9c9683 100644 --- a/test/test_population_infections.py +++ b/test/test_population_infections.py @@ -7,7 +7,7 @@ import pytest from pyrenew.deterministic import DeterministicVariable -from pyrenew.latent import RandomWalk +from pyrenew.latent import AR1, RandomWalk, StepwiseTemporalProcess from pyrenew.latent.population_infections import PopulationInfections @@ -125,6 +125,49 @@ def test_custom_name_prefix(self, gen_int_rv): assert "my_infections::rt_single" in trace + def test_first_day_dow_reaches_calendar_aligned_rt_process(self, gen_int_rv): + """Calendar-aligned stepwise Rt receives model-axis day of week.""" + process = PopulationInfections( + name="population", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + n_initialization_points=7, + ) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + process.sample(n_days_post_init=10, first_day_dow=3) + + log_rt = trace["population::log_rt_single"]["value"] + coarse = trace["log_rt_single_coarse"]["value"] + + assert log_rt.shape == (17, 1) + assert coarse.shape == (3, 1) + assert jnp.allclose(log_rt[:3], log_rt[0]) + assert jnp.allclose(log_rt[3:10], log_rt[3]) + assert jnp.allclose(log_rt[10:17], log_rt[10]) + + def test_wrong_rt_shape_raises(self, gen_int_rv, wrong_shape_temporal_process_cls): + """Temporal processes must return daily-length single-process Rt.""" + process = PopulationInfections( + name="population", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=wrong_shape_temporal_process_cls(jnp.zeros((16, 2))), + n_initialization_points=7, + ) + + with pytest.raises(ValueError, match="single_rt_process must return shape"): + process.sample(n_days_post_init=10) + class TestPopulationInfectionsValidation: """Test validation of inputs.""" diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index a0e094b8..12c1ef20 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -3,6 +3,7 @@ """ import jax.numpy as jnp +import numpyro import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable @@ -275,6 +276,115 @@ def test_prior_predictive_multi_signal(self, simple_builder): # All prior predictive infections should be positive assert jnp.all(prior_samples["latent_infections"] > 0) + def test_first_day_dow_reaches_calendar_aligned_latent_process(self): + """MultiSignalModel forwards model-axis day of week to the latent process.""" + latent = PopulationInfections( + name="PopulationInfections", + gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + n_initialization_points=3, + ) + model = MultiSignalModel(latent, {"ed": _daily_ed_counts()}) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + model.sample( + n_days_post_init=10, + population_size=1_000_000, + first_day_dow=3, + ed={"obs": None}, + ) + + log_rt = trace["PopulationInfections::log_rt_single"]["value"] + coarse = trace["log_rt_single_coarse"]["value"] + + assert log_rt.shape == (model.latent.n_initialization_points + 10, 1) + assert coarse.shape[0] < log_rt.shape[0] + assert jnp.allclose(log_rt[:3], log_rt[0]) + assert jnp.allclose(log_rt[3:10], log_rt[3]) + + def test_missing_first_day_dow_for_calendar_aligned_latent_process_raises(self): + """Calendar-aligned latent temporal processes require model-axis DOW.""" + latent = PopulationInfections( + name="PopulationInfections", + gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + n_initialization_points=3, + ) + model = MultiSignalModel(latent, {"ed": _daily_ed_counts()}) + + with numpyro.handlers.seed(rng_seed=42): + with pytest.raises(ValueError, match="first_day_dow"): + model.sample( + n_days_post_init=10, + population_size=1_000_000, + ed={"obs": None}, + ) + + def test_builder_mixed_cadence_weekly_rt_samples(self): + """ + Mixed daily/weekly observations can use weekly-parameterized Rt. + + The latent temporal process records a coarse Rt trajectory, the latent + process records daily Rt values, ED remains on the daily likelihood + scale, and hospital observations are scored on the weekly likelihood + scale. + """ + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + observations=[_weekly_hosp_counts(), _daily_ed_counts()], + ) + model = builder.build() + n_days_post_init = 28 + n_total = model.latent.n_initialization_points + n_days_post_init + first_day_dow = 6 + hospital_obs = jnp.array([jnp.nan, 5.0, 7.0, 6.0], dtype=float) + ed_obs = jnp.concatenate( + [ + jnp.full(model.latent.n_initialization_points, jnp.nan), + jnp.ones(n_days_post_init) * 3.0, + ] + ) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + model.sample( + n_days_post_init=n_days_post_init, + population_size=1_000_000, + first_day_dow=first_day_dow, + hospital={"obs": hospital_obs, "first_day_dow": first_day_dow}, + ed={"obs": ed_obs}, + ) + + assert trace["log_rt_single_coarse"]["value"].shape == (5, 1) + assert trace["PopulationInfections::log_rt_single"]["value"].shape == ( + n_total, + 1, + ) + assert trace["ed_obs"]["value"].shape == (n_total,) + assert trace["hospital_obs"]["value"].shape == hospital_obs.shape + assert trace["hospital_predicted_daily"]["value"].shape == (n_total,) + assert trace["hospital_predicted"]["value"].shape == hospital_obs.shape + class TestMultiSignalModelValidation: """Test data validation.""" @@ -546,16 +656,16 @@ def test_daily_rt_with_mixed_observations_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_weekly_rt_with_daily_observation_raises(self): - """step_size=7 with any daily obs: rule 2 violation.""" + def test_weekly_rt_with_daily_observation_passes(self): + """Coarse Rt parameter cadence is allowed with daily observations.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 ), observations=[_weekly_hosp_counts(), _daily_ed_counts()], ) - with pytest.raises(ValueError, match="exceeding the finest"): - builder.build() + model = builder.build() + assert isinstance(model, MultiSignalModel) def test_mismatched_weekly_period_end_dow_raises(self): """Two weekly observations with different period_end_dow: rule 1 violation.""" @@ -581,17 +691,83 @@ def test_matching_weekly_period_end_dow_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_step_size_mismatches_coarse_period_raises(self): - """step_size > 1 must equal every coarse observation aggregation_period.""" + def test_arbitrary_step_size_with_weekly_observation_passes(self): + """Parameter cadence need not match observation aggregation period.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=2 ), observations=[_weekly_hosp_counts()], ) - with pytest.raises(ValueError, match="All coarse cadences must agree"): + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_invalid_temporal_process_step_size_raises( + self, invalid_step_size_temporal_process + ): + """Temporal process metadata must expose a positive integer step_size.""" + builder = _coherence_builder( + single_rt_process=invalid_step_size_temporal_process, + observations=[_daily_ed_counts()], + ) + with pytest.raises(ValueError, match="positive integer step_size"): builder.build() + def test_calendar_week_alignment_matches_weekly_period_end_dow_passes(self): + """Sunday-start weeks pair with Saturday-ending weekly observations.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + observations=[_weekly_hosp_counts(period_end_dow=5)], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_calendar_week_alignment_mismatches_weekly_period_end_dow_raises(self): + """A weekly Rt anchor must agree with the weekly observation anchor.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=0, + ), + observations=[_weekly_hosp_counts(period_end_dow=5)], + ) + with pytest.raises( + ValueError, match="weekly observations should end on period_end_dow=6" + ): + builder.build() + + def test_calendar_week_alignment_with_only_daily_observations_passes(self): + """No weekly observation means no calendar-anchor agreement to enforce.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + observations=[_daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_model_index_alignment_ignores_weekly_period_end_dow(self): + """Model-index alignment carries no calendar anchor to compare against.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts(period_end_dow=5)], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + class TestMultiSignalValidateDataAnchor: """MultiSignalModel.validate_data sample-time anchor check for first_day_dow.""" diff --git a/test/test_subpopulation_infections.py b/test/test_subpopulation_infections.py index 658546e3..cc0adf58 100644 --- a/test/test_subpopulation_infections.py +++ b/test/test_subpopulation_infections.py @@ -84,6 +84,50 @@ def test_shape_and_positivity_across_subpop_counts( expected = jnp.sum(inf_all * fractions[jnp.newaxis, :], axis=1) assert jnp.allclose(inf_juris, expected, atol=1e-6) + def test_wrong_baseline_rt_shape_raises( + self, gen_int_rv, wrong_shape_temporal_process_cls + ): + """Baseline temporal process must return one daily baseline trajectory.""" + process = SubpopulationInfections( + name="subpopulation", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=wrong_shape_temporal_process_cls(jnp.zeros((16, 2))), + subpop_rt_deviation_process=RandomWalk(), + n_initialization_points=7, + ) + + with pytest.raises(ValueError, match="baseline_rt_process must return shape"): + process.sample( + n_days_post_init=10, + subpop_fractions=jnp.array([0.3, 0.25, 0.45]), + ) + + def test_wrong_subpop_deviation_shape_raises( + self, gen_int_rv, wrong_shape_temporal_process_cls, constant_temporal_process + ): + """Deviation temporal process must return one trajectory per subpop.""" + process = SubpopulationInfections( + name="subpopulation", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=constant_temporal_process, + subpop_rt_deviation_process=wrong_shape_temporal_process_cls( + jnp.zeros((17, 2)) + ), + n_initialization_points=7, + ) + + with pytest.raises( + ValueError, match="subpop_rt_deviation_process must return shape" + ): + process.sample( + n_days_post_init=10, + subpop_fractions=jnp.array([0.3, 0.25, 0.45]), + ) + class TestSubpopulationInfectionsValidation: """Test validation of inputs.""" diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index 34b141c2..f9c6628f 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -288,6 +288,43 @@ def test_repr_includes_inner_and_step_size(self): assert f"inner={inner!r}" in rendered assert "step_size=7" in rendered + def test_unknown_alignment_raises(self): + """Unknown alignment names raise.""" + with pytest.raises(ValueError, match="alignment"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="weekly", + ) + + def test_calendar_week_requires_weekly_step_size(self): + """calendar_week alignment is explicitly weekly.""" + with pytest.raises(ValueError, match="step_size=7"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=3, + alignment="calendar_week", + week_start_dow=6, + ) + + def test_calendar_week_requires_week_start_dow(self): + """calendar_week alignment requires a declared week start.""" + with pytest.raises(ValueError, match="week_start_dow"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + ) + + def test_model_index_rejects_week_start_dow(self): + """week_start_dow is only meaningful for calendar_week alignment.""" + with pytest.raises(ValueError, match="week_start_dow"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + week_start_dow=6, + ) + class TestStepwiseTemporalProcessSample: """Sample-time behavior of StepwiseTemporalProcess.""" @@ -337,6 +374,73 @@ def test_step_size_one_passthrough_shape(self): result = wrapper.sample(n_timepoints=20, n_processes=2) assert result.shape == (20, 2) + def test_calendar_week_requires_first_day_dow_at_sample_time(self): + """calendar_week alignment needs the model-axis day of week.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ) + with numpyro.handlers.seed(rng_seed=42): + with pytest.raises(ValueError, match="first_day_dow"): + wrapper.sample(n_timepoints=20, n_processes=1) + + def test_calendar_week_alignment_with_leading_partial_week(self): + """calendar_week alignment starts full blocks on week_start_dow.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ) + # first_day_dow=3 means day 0 is Thursday. With Sunday week starts, + # days 0-2 are a leading partial week, then days 3-9 are the first + # full Sunday-Saturday block on the model axis. + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample( + n_timepoints=17, + n_processes=1, + first_day_dow=3, + ) + + assert jnp.allclose(result[:3], result[0]) + assert jnp.allclose(result[3:10], result[3]) + assert jnp.allclose(result[10:17], result[10]) + assert not jnp.allclose(result[2], result[3]) + + def test_calendar_week_alignment_without_leading_partial_week(self): + """calendar_week alignment handles model axes starting on week_start_dow.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample( + n_timepoints=15, + n_processes=1, + first_day_dow=6, + ) + + assert jnp.allclose(result[:7], result[0]) + assert jnp.allclose(result[7:14], result[7]) + assert jnp.allclose(result[14:15], result[14]) + + def test_coarse_trajectory_is_recorded(self): + """StepwiseTemporalProcess records the coarse trajectory.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ) + traced = numpyro.handlers.trace( + numpyro.handlers.seed(wrapper.sample, rng_seed=42) + ).get_trace(n_timepoints=15, n_processes=1, name_prefix="rt") + + assert "rt_coarse" in traced + assert traced["rt_coarse"]["type"] == "deterministic" + assert traced["rt_coarse"]["value"].shape == (3, 1) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 46f1ea0e07494f83297d5083c0df77ce7a2e4de1 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Wed, 22 Apr 2026 20:09:38 -0400 Subject: [PATCH 10/17] code cleanup --- pyrenew/datasets/datagen_he_CA_126.py | 16 ++--- pyrenew/observation/base.py | 4 +- pyrenew/observation/count_observations.py | 9 ++- pyrenew/time.py | 49 ++++++++++++++- test/integration/conftest.py | 76 ++++++++++++++++++++++- test/test_datagen_he_CA_126.py | 8 +-- 6 files changed, 142 insertions(+), 20 deletions(-) diff --git a/pyrenew/datasets/datagen_he_CA_126.py b/pyrenew/datasets/datagen_he_CA_126.py index 476c8036..c1358060 100644 --- a/pyrenew/datasets/datagen_he_CA_126.py +++ b/pyrenew/datasets/datagen_he_CA_126.py @@ -210,7 +210,7 @@ def run_renewal( def apply_day_of_week_effects( - values: np.ndarray, dow_effects: np.ndarray, first_dow: int + values: np.ndarray, dow_effects: np.ndarray, first_day_dow: int ) -> np.ndarray: """ Apply multiplicative day-of-week effects to daily values. @@ -222,7 +222,7 @@ def apply_day_of_week_effects( dow_effects : np.ndarray Multiplicative effects for each day (length 7, ISO convention: 0 = Monday, 6 = Sunday). Should sum to 7 to preserve weekly totals. - first_dow : int + first_day_dow : int ISO day-of-week of the first element in values. Returns @@ -230,7 +230,7 @@ def apply_day_of_week_effects( np.ndarray Adjusted daily values. """ - day_indices = get_sequential_day_of_week_indices(first_dow, len(values)) + day_indices = get_sequential_day_of_week_indices(first_day_dow, len(values)) return values * dow_effects[day_indices] @@ -282,8 +282,8 @@ def build_weekly_hosp_frame( pl.DataFrame Columns: week_end, weekly_hosp_admits. """ - first_dow = start_date.weekday() - days_to_first_sunday = (6 - first_dow) % 7 + first_day_dow = start_date.weekday() + days_to_first_sunday = (6 - first_day_dow) % 7 first_week_end = start_date + timedelta(days=days_to_first_sunday + 6) n_weeks = len(weekly_values) week_ends = [first_week_end + timedelta(weeks=i) for i in range(n_weeks)] @@ -308,7 +308,7 @@ def generate() -> None: infections_obs = infections_full[N_INIT:] obs_dates = [START_DATE + timedelta(days=i) for i in range(n_days)] - first_dow = START_DATE.weekday() + first_day_dow = START_DATE.weekday() expected_hosp_daily, _ = compute_delay_ascertained_incidence( latent_incidence=infections_full, @@ -327,7 +327,7 @@ def generate() -> None: # generative assumption of PopulationCounts(aggregation_period=7) with a # single NegativeBinomialNoise at the reporting cadence. expected_hosp_weekly = np.asarray( - daily_to_mmwr_epiweekly(expected_hosp_daily, input_data_first_dow=first_dow) + daily_to_mmwr_epiweekly(expected_hosp_daily, input_data_first_dow=first_day_dow) ) expected_hosp_weekly = np.maximum(expected_hosp_weekly, 1.0) hosp_weekly_obs = sample_negbinom( @@ -352,7 +352,7 @@ def generate() -> None: pad=True, ) expected_ed = expected_ed[N_INIT:] - expected_ed = apply_day_of_week_effects(expected_ed, DOW_EFFECTS, first_dow) + expected_ed = apply_day_of_week_effects(expected_ed, DOW_EFFECTS, first_day_dow) expected_ed = np.maximum(expected_ed, 1.0) ed_obs = sample_negbinom(expected_ed, NEGBINOM_CONCENTRATION_ED, rng) diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 136ce89b..5abb076b 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -215,7 +215,9 @@ def _validate_aggregation_params( aggregation_period Width of the reporting period in fundamental time units. period_end_dow - Day-of-week index of each period's final day. + Day-of-week index of each weekly period's final day, used as + the calendar anchor for weekly aggregation (e.g., 5 for MMWR + Sunday-Saturday epiweeks; 6 for ISO Monday-Sunday weeks). Raises ------ diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 8b70f217..c6e4b670 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -83,9 +83,12 @@ def __init__( (sparse observation array with user-supplied period-end time indices). period_end_dow - Day-of-week index of each period's final day - (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period == 7``; ignored otherwise. + Day-of-week index of each weekly period's final day (e.g., 5 + for MMWR Sunday-Saturday epiweeks; 6 for ISO Monday-Sunday + weeks). Required when ``aggregation_period == 7``; ignored + otherwise. Anchors the weekly likelihood: daily predictions + are bucketed into weeks ending on this day, then summed before + scoring. (0=Monday, 6=Sunday, ISO convention.) Raises ------ diff --git a/pyrenew/time.py b/pyrenew/time.py index 8d9357c3..54021917 100644 --- a/pyrenew/time.py +++ b/pyrenew/time.py @@ -1,8 +1,51 @@ """ -Helper functions for handling timeseries in Pyrenew +Helper functions for handling timeseries in Pyrenew. -Days of the week in pyrenew are 0-indexed and follow -ISO standards, so 0 is Monday at 6 is Sunday. +Days of the week in pyrenew are 0-indexed and follow ISO standards, +so 0 is Monday and 6 is Sunday. + +Calendar concepts used elsewhere in pyrenew +------------------------------------------- + +Three day-of-week values appear across the codebase. They name +distinct things; do not conflate them. + +- ``first_day_dow`` — day-of-week of element 0 of a model's shared + daily time axis. Sample-time data fact (depends on the dataset + being fit). Threaded through ``MultiSignalModel.sample`` and the + ``TemporalProcess.sample`` protocol. + +- ``period_end_dow`` — day-of-week on which each weekly observation + period ends. Construction-time modeling choice on + ``CountObservation``. MMWR convention: 5 (Saturday). ISO + convention: 6 (Sunday). + +- ``week_start_dow`` — day-of-week on which weekly Rt blocks begin. + Construction-time modeling choice on ``StepwiseTemporalProcess`` + with ``alignment="calendar_week"``. Also the parameter name used by + ``daily_to_weekly`` / ``weekly_to_daily`` in this module. + +The two anchors are complements: a weekly observation period that ends +on day X is contained in the calendar week that starts on day +``(X + 1) % 7``. ``PyrenewBuilder._validate_coherence`` enforces +``period_end_dow == (week_start_dow + 6) % 7`` when both are present. + +Worked example: ``first_day_dow=3`` (Thursday), ``week_start_dow=6`` +(Sunday), 17-day axis:: + + model index: 0 1 2 | 3 4 5 6 7 8 9 | 10 11 12 13 14 15 16 + weekday: Th Fr Sa | Su Mo Tu We Th Fr Sa | Su Mo Tu We Th Fr Sa + weekly Rt: [ c0 ] [ c1 ] [ c2 ] + +The first three indices fall in the leading partial week (``c0``); +each subsequent block of 7 days falls in one full Sunday-Saturday +week and shares one Rt sample. + +The conversion between the model-axis fact and these calendar anchors +happens at call sites (e.g., ``daily_to_weekly(daily_predicted, +input_data_first_dow=first_day_dow, week_start_dow=...)``); functions +in this module use locally scoped role names like +``input_data_first_dow`` and ``output_data_first_dow``. """ import datetime as dt diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 68337b9c..202ca128 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -21,7 +21,7 @@ load_synthetic_weekly_hospital_admissions, ) from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import AR1 +from pyrenew.latent import AR1, StepwiseTemporalProcess from pyrenew.latent.population_infections import PopulationInfections from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import NegativeBinomialNoise, PopulationCounts @@ -204,6 +204,80 @@ def he_model( return builder.build() +@pytest.fixture(scope="module") +def he_weekly_rt_model( + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """ + Build a PopulationInfections model with weekly-parameterized R(t). + + Same observation configuration as ``he_weekly_model`` (weekly hospital + admissions on the MMWR epiweek grid + daily ED visits with a day-of-week + effect), but R(t) is sampled weekly and broadcast to daily via + ``StepwiseTemporalProcess`` with calendar-week alignment. This mirrors the + production pyrenew-hew configuration. + + Parameters + ---------- + hosp_delay_pmf : jnp.ndarray + Infection-to-hospitalization delay PMF. + ed_delay_pmf : jnp.ndarray + Infection-to-ED-visit delay PMF. + ed_day_of_week_effects : jnp.ndarray + Day-of-week multipliers used in synthetic ED generation. + + Returns + ------- + MultiSignalModel + Built model ready for fitting. + """ + gen_int_pmf = jnp.array( + [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] + ) + + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), + I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week_start_dow=6, + ), + ) + + hospital_obs = PopulationCounts( + name="hospital", + ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) + ), + aggregation_period=7, + reporting_schedule="regular", + period_end_dow=5, + ) + builder.add_observation(hospital_obs) + + ed_obs = PopulationCounts( + name="ed", + ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) + ), + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), + ) + builder.add_observation(ed_obs) + + return builder.build() + + @pytest.fixture(scope="module") def he_weekly_model( hosp_delay_pmf: jnp.ndarray, diff --git a/test/test_datagen_he_CA_126.py b/test/test_datagen_he_CA_126.py index ebbfe997..9ab12e3d 100644 --- a/test/test_datagen_he_CA_126.py +++ b/test/test_datagen_he_CA_126.py @@ -72,22 +72,22 @@ def test_uniform_effects(self): """Test that uniform effects leave values unchanged.""" values = np.array([10.0, 20.0, 30.0]) dow = np.ones(7) - result = apply_day_of_week_effects(values, dow, first_dow=0) + result = apply_day_of_week_effects(values, dow, first_day_dow=0) np.testing.assert_allclose(result, values) def test_known_pattern(self): """Test that known day-of-week effects are applied correctly.""" values = np.ones(7) * 100.0 dow = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5]) - result = apply_day_of_week_effects(values, dow, first_dow=0) + result = apply_day_of_week_effects(values, dow, first_day_dow=0) np.testing.assert_allclose(result[5], 50.0) np.testing.assert_allclose(result[6], 50.0) def test_offset_start(self): - """Test that first_dow correctly offsets the pattern.""" + """Test that first_day_dow correctly offsets the pattern.""" values = np.ones(3) * 10.0 dow = np.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - result = apply_day_of_week_effects(values, dow, first_dow=6) + result = apply_day_of_week_effects(values, dow, first_day_dow=6) np.testing.assert_allclose(result[0], 10.0) np.testing.assert_allclose(result[1], 20.0) From e70c86d0da78278861cce2e763b41909d5d0e0a7 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 23 Apr 2026 13:17:30 -0400 Subject: [PATCH 11/17] refactor to WeekCycle --- .../tutorials/building_multisignal_models.qmd | 8 +- pyrenew/latent/temporal_processes.py | 44 +- pyrenew/model/pyrenew_builder.py | 51 +-- pyrenew/observation/base.py | 103 ++--- pyrenew/observation/count_observations.py | 419 +++++++++++------- pyrenew/time.py | 82 +++- test/conftest.py | 29 +- test/integration/conftest.py | 20 +- .../test_population_infections_he_weekly.py | 11 +- ...test_population_infections_he_weekly_rt.py | 381 ++++++++++++++++ test/test_observation_counts.py | 42 +- test/test_observation_validation.py | 148 +++---- test/test_population_infections.py | 3 +- test/test_pyrenew_builder.py | 55 ++- test/test_temporal_processes.py | 23 +- test/test_time.py | 49 ++ 16 files changed, 972 insertions(+), 496 deletions(-) create mode 100644 test/integration/test_population_infections_he_weekly_rt.py diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 5b97ea92..53ec13e6 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -80,6 +80,7 @@ from pyrenew.observation import ( MeasurementNoise, NegativeBinomialNoise, ) +from pyrenew.time import MMWR_WEEK ``` ## Overview @@ -261,12 +262,13 @@ weekly_baseline_rt_process = StepwiseTemporalProcess( inner=AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, # Sunday + week=MMWR_WEEK, ) ``` -Use `alignment="calendar_week"` when the weekly Rt blocks should align to a known weekday. -For MMWR-style weeks ending on Saturday, use `week_start_dow=6` for Sunday starts; the corresponding weekly count observation would use `period_end_dow=5`. +Use `alignment="calendar_week"` when the weekly Rt blocks should align to a calendar week. +Pass the same `WeekCycle` to every weekly component that must agree on the calendar: the `StepwiseTemporalProcess` above and any weekly `PopulationCounts` or `SubpopulationCounts` observation. +`pyrenew.time` exports `MMWR_WEEK` (Sunday-Saturday epiweeks) and `ISO_WEEK` (Monday-Sunday); use `WeekCycle(start_dow=k)` for any other convention. At sample or run time, pass the day of week for element 0 of the shared model axis: ```python diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 98b55d81..8fbe46fd 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -60,7 +60,7 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.process.randomwalk import RandomWalk as ProcessRandomWalk from pyrenew.randomvariable import DistributionalVariable -from pyrenew.time import validate_dow, weekly_to_daily +from pyrenew.time import WeekCycle, validate_dow, weekly_to_daily @runtime_checkable @@ -463,12 +463,13 @@ class StepwiseTemporalProcess(TemporalProcess): alignment How repeated blocks align to the output time axis. ``"model_index"`` starts blocks at output index 0. ``"calendar_week"`` aligns blocks - to calendar weeks using ``week_start_dow`` and the ``first_day_dow`` + to calendar weeks using ``week`` and the ``first_day_dow`` supplied to ``sample()``. - week_start_dow - Day of week on which weekly blocks begin when - ``alignment="calendar_week"`` (0=Monday, ..., 6=Sunday). Required - for calendar-week alignment. + week + Calendar-week anchor used when ``alignment="calendar_week"`` + (e.g., :data:`pyrenew.time.MMWR_WEEK`, + :data:`pyrenew.time.ISO_WEEK`). Required for calendar-week + alignment; must be ``None`` otherwise. Raises ------ @@ -484,7 +485,7 @@ def __init__( inner: TemporalProcess, step_size: int, alignment: Literal["model_index", "calendar_week"] = "model_index", - week_start_dow: int | None = None, + week: WeekCycle | None = None, ) -> None: """ Initialize stepwise temporal process. @@ -498,11 +499,11 @@ def __init__( alignment How repeated blocks align to the output time axis. ``"model_index"`` starts blocks at output index 0. ``"calendar_week"`` aligns - weekly blocks to ``week_start_dow`` using ``first_day_dow`` at + weekly blocks to ``week`` using ``first_day_dow`` at sample time. - week_start_dow - Day of week on which weekly blocks begin when - ``alignment="calendar_week"`` (0=Monday, ..., 6=Sunday). + week + Calendar-week anchor used when + ``alignment="calendar_week"``. Raises ------ @@ -523,26 +524,21 @@ def __init__( "calendar_week alignment requires step_size=7, " f"got step_size={step_size}" ) - if week_start_dow is None: - raise ValueError( - "week_start_dow is required when alignment='calendar_week'" - ) - validate_dow(week_start_dow, "week_start_dow") - elif week_start_dow is not None: - raise ValueError( - "week_start_dow is only used when alignment='calendar_week'" - ) + if week is None: + raise ValueError("week is required when alignment='calendar_week'") + elif week is not None: + raise ValueError("week is only used when alignment='calendar_week'") self.inner = inner self.step_size = step_size self.alignment = alignment - self.week_start_dow = week_start_dow + self.week = week def __repr__(self) -> str: """Return string representation.""" return ( f"StepwiseTemporalProcess(inner={self.inner!r}, " f"step_size={self.step_size}, alignment={self.alignment!r}, " - f"week_start_dow={self.week_start_dow!r})" + f"week={self.week!r})" ) def _resolve_n_coarse(self, n_timepoints: int, first_day_dow: int | None) -> int: @@ -563,7 +559,7 @@ def _resolve_n_coarse(self, n_timepoints: int, first_day_dow: int | None) -> int "alignment='calendar_week'" ) validate_dow(first_day_dow, "first_day_dow") - trim = (first_day_dow - self.week_start_dow) % 7 + trim = (first_day_dow - self.week.start_dow) % 7 return (n_timepoints + trim + 6) // 7 def sample( @@ -619,6 +615,6 @@ def sample( return jnp.repeat(coarse, repeats=self.step_size, axis=0)[:n_timepoints] return weekly_to_daily( coarse, - week_start_dow=self.week_start_dow, + week_start_dow=self.week.start_dow, output_data_first_dow=first_day_dow, )[:n_timepoints] diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index c6447d49..728db2a4 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -202,32 +202,28 @@ def _validate_coherence(self) -> None: Checks: - - All observations sharing the same ``aggregation_period > 1`` must - agree on ``period_end_dow``. + - All weekly observations must share a single + :class:`pyrenew.time.WeekCycle`. - Every temporal-process ``step_size`` must be a positive integer. - - Calendar-week-aligned temporal processes must have a - ``week_start_dow`` consistent with the ``period_end_dow`` of any - weekly observation: ``period_end_dow == (week_start_dow + 6) % 7``. + - Calendar-week-aligned temporal processes must share that + :class:`WeekCycle`. Raises ------ ValueError If any of the above checks fails. """ - # Intentionally permissive: parameter cadence and observation cadence - # are independent. - agg_by_dow: dict[int, set[int]] = {} - for name, obs in self.observations.items(): - P = getattr(obs, "aggregation_period", 1) - if P > 1: - agg_by_dow.setdefault(P, set()).add(getattr(obs, "period_end_dow")) - - for P, dows in agg_by_dow.items(): - if len(dows) > 1: - raise ValueError( - f"Observations with aggregation_period={P} must agree on " - f"period_end_dow; got values {sorted(dows)}" - ) + weekly_weeks = { + obs.week + for obs in self.observations.values() + if getattr(obs, "aggregation", "daily") == "weekly" + } + if len(weekly_weeks) > 1: + raise ValueError( + f"Weekly observations must share a single WeekCycle; " + f"got {sorted(weekly_weeks, key=lambda w: w.start_dow)}" + ) + obs_week = next(iter(weekly_weeks), None) temporal_processes = { name: value @@ -235,8 +231,6 @@ def _validate_coherence(self) -> None: if isinstance(value, TemporalProcess) } - weekly_period_end_dow = next(iter(agg_by_dow.get(7, set())), None) - for param_name, process in temporal_processes.items(): step_size = getattr(process, "step_size", 1) if not isinstance(step_size, int) or step_size < 1: @@ -246,17 +240,14 @@ def _validate_coherence(self) -> None: ) if ( getattr(process, "alignment", None) == "calendar_week" - and weekly_period_end_dow is not None + and obs_week is not None ): - week_start_dow = getattr(process, "week_start_dow", None) - expected_period_end_dow = (week_start_dow + 6) % 7 - if expected_period_end_dow != weekly_period_end_dow: + proc_week = getattr(process, "week", None) + if proc_week != obs_week: raise ValueError( - f"Temporal process '{param_name}' has " - f"week_start_dow={week_start_dow}, which implies " - f"weekly observations should end on " - f"period_end_dow={expected_period_end_dow}; got " - f"period_end_dow={weekly_period_end_dow}" + f"Temporal process '{param_name}' has week={proc_week!r}, " + f"which disagrees with the weekly observation " + f"week={obs_week!r}" ) def build(self) -> MultiSignalModel: diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 5abb076b..bdddf601 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -17,6 +17,7 @@ from pyrenew.arrayutils import require_shape from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.metaclass import RandomVariable +from pyrenew.time import WeekCycle class BaseObservationProcess(RandomVariable): @@ -196,114 +197,74 @@ def _validate_dow_effect( if jnp.any(dow_effect < 0): raise ValueError(f"{param_name} must have non-negative values") - def _validate_aggregation_params( + def _validate_week( self, - aggregation_period: int, - period_end_dow: int | None, + aggregation: str, + week: WeekCycle | None, ) -> None: """ - Validate temporal-aggregation constructor parameters. + Validate the ``(aggregation, week)`` pair. - Checks that ``aggregation_period`` is an integer in - ``{1, 7}``, and that ``period_end_dow`` is an integer in - ``{0, ..., 6}`` (0=Monday, 6=Sunday, ISO convention) when - ``aggregation_period == 7``. ``period_end_dow`` is ignored - when ``aggregation_period == 1``. + ``aggregation="weekly"`` requires a :class:`WeekCycle`; + ``aggregation="daily"`` ignores ``week``. Parameters ---------- - aggregation_period - Width of the reporting period in fundamental time units. - period_end_dow - Day-of-week index of each weekly period's final day, used as - the calendar anchor for weekly aggregation (e.g., 5 for MMWR - Sunday-Saturday epiweeks; 6 for ISO Monday-Sunday weeks). + aggregation + Observation reporting cadence; one of ``"daily"`` or + ``"weekly"``. + week + Calendar-week anchor; required iff + ``aggregation == "weekly"``. Raises ------ ValueError - If ``aggregation_period`` is not in ``{1, 7}``, or if - ``period_end_dow`` is missing or out of range when - ``aggregation_period == 7``. - - Notes - ----- - ``period_end_dow`` is a weekly-specific anchor; generalizing - beyond ``aggregation_period == 7`` will require a different - anchor representation. + If ``aggregation`` is unrecognized, or if + ``aggregation == "weekly"`` and ``week`` is ``None``. """ - if not isinstance(aggregation_period, int) or aggregation_period < 1: - raise ValueError( - "aggregation_period must be a positive integer, " - f"got {aggregation_period!r}" - ) - if aggregation_period not in (1, 7): + if aggregation not in ("daily", "weekly"): raise ValueError( - f"aggregation_period must be one of {{1, 7}}, got {aggregation_period}" + f"aggregation must be one of {{'daily', 'weekly'}}, got {aggregation!r}" ) - if aggregation_period == 7: - if period_end_dow is None: - raise ValueError( - "period_end_dow is required when aggregation_period == 7" - ) - if not isinstance(period_end_dow, int) or not (0 <= period_end_dow <= 6): - raise ValueError( - "period_end_dow must be an integer in {0, ..., 6} " - f"(0=Monday, 6=Sunday), got {period_end_dow!r}" - ) + if aggregation == "weekly" and week is None: + raise ValueError("week is required when aggregation == 'weekly'") def _compute_period_offset( self, first_day_dow: int | None, - aggregation_period: int, - period_end_dow: int | None, + week: WeekCycle | None, ) -> int: """ Compute the number of leading daily timepoints to trim so - that the daily axis aligns to whole aggregation periods. + that the daily axis aligns to whole weekly periods. Parameters ---------- first_day_dow Day-of-week index of element 0 of the daily axis (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period == 7``. - aggregation_period - Width of the reporting period. Must be in ``{1, 7}``. - period_end_dow - Day-of-week index of each period's final day. Required - when ``aggregation_period == 7``. + ``week`` is not ``None``. + week + Calendar-week anchor. ``None`` indicates daily + (non-aggregated) observations; in that case the + offset is ``0``. Returns ------- int - Trim offset in ``[0, aggregation_period)``. Returns - ``0`` when ``aggregation_period == 1``. + Trim offset in ``[0, 7)``. Returns ``0`` when ``week`` is ``None``. Raises ------ ValueError - If ``aggregation_period`` is not in ``{1, 7}``, or if - ``first_day_dow`` or ``period_end_dow`` is ``None`` when - ``aggregation_period == 7``. - - Notes - ----- - For ``aggregation_period == 7``: - ``(period_end_dow + 1 - first_day_dow) % 7``. + If ``week`` is provided but ``first_day_dow`` is ``None``. """ - if aggregation_period == 1: + if week is None: return 0 - if aggregation_period != 7: - raise ValueError( - f"aggregation_period must be one of {{1, 7}}, got {aggregation_period}" - ) - if first_day_dow is None or period_end_dow is None: - raise ValueError( - "first_day_dow and period_end_dow are both required " - "when aggregation_period == 7" - ) - return (period_end_dow + 1 - first_day_dow) % aggregation_period + if first_day_dow is None: + raise ValueError("first_day_dow is required when week is not None") + return (week.end_dow + 1 - first_day_dow) % 7 def _convolve_with_alignment( self, diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index c6e4b670..9228ba66 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -5,6 +5,8 @@ from __future__ import annotations +from typing import Literal + import jax import jax.numpy as jnp from jax.typing import ArrayLike @@ -14,7 +16,11 @@ from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise from pyrenew.observation.types import ObservationSample -from pyrenew.time import daily_to_weekly, get_sequential_day_of_week_indices +from pyrenew.time import ( + WeekCycle, + daily_to_weekly, + get_sequential_day_of_week_indices, +) class CountObservation(BaseObservationProcess): @@ -38,9 +44,9 @@ def __init__( noise: CountNoise, right_truncation_rv: RandomVariable | None = None, day_of_week_rv: RandomVariable | None = None, - aggregation_period: int = 1, - reporting_schedule: str = "regular", - period_end_dow: int | None = None, + aggregation: Literal["daily", "weekly"] = "daily", + reporting_schedule: Literal["regular", "irregular"] = "regular", + week: WeekCycle | None = None, ) -> None: """ Initialize count observation base. @@ -71,51 +77,63 @@ def __init__( overall predicted counts. When provided (along with ``first_day_dow`` at sample time), predicted counts are scaled by a periodic weekly pattern. - aggregation_period - Width of the observation reporting period in days. Must be in - ``{1, 7}``. ``1`` means no aggregation (daily observations). - This controls only the scale on which the count likelihood is - evaluated. It does not control how often the latent Rt temporal - process samples new parameters. + aggregation + Observation reporting cadence; one of ``"daily"`` or + ``"weekly"``. Controls only the scale on which the count + likelihood is evaluated; it does not control how often the + latent Rt temporal process samples new parameters. reporting_schedule Either ``"regular"`` (dense observation array, one entry per period, NaN for unobserved periods) or ``"irregular"`` (sparse observation array with user-supplied period-end time indices). - period_end_dow - Day-of-week index of each weekly period's final day (e.g., 5 - for MMWR Sunday-Saturday epiweeks; 6 for ISO Monday-Sunday - weeks). Required when ``aggregation_period == 7``; ignored - otherwise. Anchors the weekly likelihood: daily predictions - are bucketed into weeks ending on this day, then summed before - scoring. (0=Monday, 6=Sunday, ISO convention.) + week + Calendar-week anchor used for weekly aggregation (e.g., + :data:`pyrenew.time.MMWR_WEEK` for Sunday-Saturday + epiweeks, :data:`pyrenew.time.ISO_WEEK` for Monday-Sunday + weeks). Required when ``aggregation == "weekly"``; ignored + otherwise. Daily predictions are bucketed into weeks + according to ``week`` and summed before scoring. Raises ------ ValueError - If aggregation/reporting parameters are invalid, or if a - day-of-week effect is combined with ``aggregation_period > 1`` - (within-period structure is destroyed by aggregation). + If ``aggregation``, ``reporting_schedule``, or ``week`` are + invalid, or if a day-of-week effect is combined with + ``aggregation == "weekly"`` (within-period structure is + destroyed by aggregation). """ super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv) self.ascertainment_rate_rv = ascertainment_rate_rv self.noise = noise self.right_truncation_rv = right_truncation_rv self.day_of_week_rv = day_of_week_rv - self._validate_aggregation_params(aggregation_period, period_end_dow) + self._validate_week(aggregation, week) if reporting_schedule not in self._SUPPORTED_SCHEDULES: raise ValueError( f"reporting_schedule must be one of {self._SUPPORTED_SCHEDULES}, " f"got {reporting_schedule!r}" ) - if aggregation_period > 1 and day_of_week_rv is not None: + if aggregation == "weekly" and day_of_week_rv is not None: raise ValueError( - "day_of_week_rv cannot be combined with aggregation_period > 1; " + "day_of_week_rv cannot be combined with aggregation == 'weekly'; " "aggregation destroys within-period structure." ) - self.aggregation_period = aggregation_period + self.aggregation = aggregation self.reporting_schedule = reporting_schedule - self.period_end_dow = period_end_dow + self.week = week + + @property + def aggregation_period(self) -> int: + """ + Width of the observation reporting period in days. + + Returns + ------- + int + ``1`` for daily aggregation, ``7`` for weekly. + """ + return 7 if self.aggregation == "weekly" else 1 def validate(self) -> None: """ @@ -299,9 +317,9 @@ def _aggregate( """ Aggregate daily predicted counts to the observation reporting grid. - When ``aggregation_period == 1`` returns the input unchanged. + When ``aggregation == "daily"`` returns the input unchanged. Otherwise sums daily values over non-overlapping fixed-width - periods anchored by ``period_end_dow``, via + periods anchored by ``week``, via ``pyrenew.time.daily_to_weekly``. Works on both 1D ``(n_total,)`` and 2D ``(n_total, n_subpops)`` inputs. This aggregation is part of the observation likelihood path and is @@ -314,7 +332,7 @@ def _aggregate( first_day_dow Day-of-week index of element 0 of the daily axis (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period > 1``. + ``aggregation == "weekly"``. Returns ------- @@ -322,23 +340,190 @@ def _aggregate( Aggregated counts on the period grid; same trailing dimensions as ``predicted_daily``. Returns ``predicted_daily`` unchanged when - ``aggregation_period == 1``. + ``aggregation == "daily"``. Raises ------ ValueError - If ``aggregation_period > 1`` and ``first_day_dow`` is ``None``. + If ``aggregation == "weekly"`` and ``first_day_dow`` is ``None``. """ - if self.aggregation_period == 1: + if self.aggregation == "daily": return predicted_daily if first_day_dow is None: - raise ValueError("first_day_dow is required when aggregation_period > 1") + raise ValueError("first_day_dow is required when aggregation == 'weekly'") return daily_to_weekly( predicted_daily, input_data_first_dow=first_day_dow, - week_start_dow=(self.period_end_dow + 1) % 7, + week_start_dow=self.week.start_dow, ) + def _compute_predicted( + self, + infections: ArrayLike, + first_day_dow: int | None, + right_truncation_offset: int | None, + ) -> ArrayLike: + """ + Build the predicted counts on the reporting-period grid. + + Runs ascertainment and delay convolution, then optionally + applies the day-of-week multiplicative effect and + right-truncation adjustment, and aggregates to the reporting + grid. Emits ``predicted`` (and ``predicted_daily`` when + aggregating) as numpyro deterministic sites. + + Parameters + ---------- + infections + Infections from the latent process. Shape ``(n_total,)`` + for aggregate, ``(n_total, n_subpops)`` for subpopulation-level. + first_day_dow + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction or when + ``aggregation == "weekly"``. + right_truncation_offset + If set (together with ``right_truncation_rv``), the + number of additional reporting days that have occurred + since the last observation. + + Returns + ------- + ArrayLike + Predicted counts on the reporting-period grid; same + trailing dimensions as ``infections``. Equal to + predicted-daily when ``aggregation == "daily"``. + + Raises + ------ + ValueError + If ``day_of_week_rv`` was set but ``first_day_dow`` is + ``None``. + """ + predicted_daily = self._predicted_obs(infections) + if self.day_of_week_rv is not None: + if first_day_dow is None: + raise ValueError( + "first_day_dow is required when day_of_week_rv is set." + ) + predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) + if self.right_truncation_rv is not None and right_truncation_offset is not None: + predicted_daily = self._apply_right_truncation( + predicted_daily, right_truncation_offset + ) + predicted = self._aggregate(predicted_daily, first_day_dow) + if self.aggregation == "weekly": + self._deterministic("predicted_daily", predicted_daily) + self._deterministic("predicted", predicted) + return predicted + + def _score_masked( + self, + predicted: ArrayLike, + obs: ArrayLike | None, + ) -> ArrayLike: + """ + Evaluate the masked likelihood on a dense period grid. + + Builds a boolean mask from the non-NaN positions of + ``predicted`` and (if provided) ``obs``, replaces NaN entries + with safe placeholder values, and delegates to + ``noise.sample`` with the mask. Shape-agnostic: works on 1D + period grids and 2D ``(n_periods, n_subpops)`` grids. + + Parameters + ---------- + predicted + Predicted counts on the reporting-period grid. NaN + entries mark the initialization period. + obs + Observed counts of the same shape as ``predicted``, or + ``None`` for prior predictive sampling. NaN entries mark + unobserved periods. + + Returns + ------- + ArrayLike + Sampled or conditioned counts from the noise model. + + Notes + ----- + JAX evaluates ``log_prob`` for every element regardless of + the mask; replacing NaN with finite placeholders prevents + NaN propagation in the trace while ``mask=False`` excludes + those entries from the likelihood sum. + """ + valid_pred = ~jnp.isnan(predicted) + if obs is not None: + valid_obs = ~jnp.isnan(obs) + mask = valid_pred & valid_obs + else: + mask = valid_pred + safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) + safe_obs = None + if obs is not None: + safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + return self.noise.sample( + name=self._sample_site_name("obs"), + predicted=safe_predicted, + obs=safe_obs, + mask=mask, + ) + + def _period_indices( + self, + period_end_times: ArrayLike, + first_day_dow: int | None, + ) -> jnp.ndarray: + """ + Convert daily-axis period-end indices to period-grid indices. + + For each daily-axis index ``t`` identifying the final day of + a reporting period, returns the position of that period in + the aggregated output. + + Parameters + ---------- + period_end_times + Daily-axis indices of each observed period's final day. + first_day_dow + Day-of-week index of element 0 of the daily axis. Required + when ``aggregation == "weekly"``. + + Returns + ------- + jnp.ndarray + Period-grid indices, one per entry in ``period_end_times``. + """ + P = self.aggregation_period + offset = self._compute_period_offset(first_day_dow, self.week) + return (jnp.asarray(period_end_times) - offset - (P - 1)) // P + + def _n_periods(self, n_total: int, first_day_dow: int | None) -> int: + """ + Return the number of complete reporting periods in ``n_total`` days. + + Parameters + ---------- + n_total + Total number of daily time steps (``n_init + n_days_post_init``). + first_day_dow + Day-of-week index of element 0 of the daily axis. Required + when ``aggregation == "weekly"``. + + Returns + ------- + int + ``n_total`` when ``aggregation == "daily"``; otherwise the + number of complete weekly periods after trimming the + leading partial week. + """ + if self.aggregation == "daily": + return n_total + P = self.aggregation_period + offset = self._compute_period_offset(first_day_dow, self.week) + return (n_total - offset) // P + class PopulationCounts(CountObservation): """ @@ -415,7 +600,7 @@ def validate_data( first_day_dow Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period > 1`` so weekly observation periods can be + ``aggregation == "weekly"`` so weekly observation periods can be aligned to the shared daily model axis. **kwargs Additional keyword arguments (ignored). @@ -425,23 +610,15 @@ def validate_data( ValueError If obs length or period_end_times fail their respective checks, or if ``first_day_dow`` is missing when - ``aggregation_period > 1``. + ``aggregation == "weekly"``. """ - P = self.aggregation_period - if self.reporting_schedule == "regular": if obs is None: return - if P == 1: + if self.aggregation == "daily": self._validate_obs_dense(obs, n_total) return - if first_day_dow is None: - raise ValueError( - f"Observation '{self.name}': first_day_dow is required " - f"when aggregation_period == {P}" - ) - offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) - n_periods = (n_total - offset) // P + n_periods = self._n_periods(n_total, first_day_dow) obs = jnp.asarray(obs) if obs.ndim != 1: raise ValueError( @@ -457,13 +634,10 @@ def validate_data( if period_end_times is None: return - if P > 1 and first_day_dow is None: - raise ValueError( - f"Observation '{self.name}': first_day_dow is required " - f"when aggregation_period == {P}" - ) - offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) - self._validate_period_end_times(period_end_times, n_total, offset, P) + offset = self._compute_period_offset(first_day_dow, self.week) + self._validate_period_end_times( + period_end_times, n_total, offset, self.aggregation_period + ) if obs is not None: self._validate_shapes_match( obs, period_end_times, "obs", "period_end_times" @@ -481,7 +655,7 @@ def sample( Sample aggregated counts. Daily transforms (right-truncation, day-of-week) run on the - daily axis. When ``aggregation_period > 1`` the daily + daily axis. When ``aggregation == "weekly"`` the daily predictions are summed onto the reporting-period grid before the noise model. Likelihood path depends on ``reporting_schedule``: ``"regular"`` uses a dense-with-NaN @@ -513,7 +687,7 @@ def sample( Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when ``day_of_week_rv`` was set at construction or when - ``aggregation_period > 1``. This aligns observation-level + ``aggregation == "weekly"``. This aligns observation-level day-of-week effects or weekly aggregation to the shared daily model axis. period_end_times @@ -526,62 +700,24 @@ def sample( Named tuple with ``observed`` (sampled/conditioned counts) and ``predicted`` (predictions on the reporting-period grid; equal to daily predictions when - ``aggregation_period == 1``). + ``aggregation == "daily"``). """ - predicted_daily = self._predicted_obs(infections) - if self.day_of_week_rv is not None: - if first_day_dow is None: - raise ValueError( - "first_day_dow is required when day_of_week_rv is set." - ) - predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) - if self.right_truncation_rv is not None and right_truncation_offset is not None: - predicted_daily = self._apply_right_truncation( - predicted_daily, right_truncation_offset - ) - - predicted = self._aggregate(predicted_daily, first_day_dow) - if self.aggregation_period > 1: - self._deterministic("predicted_daily", predicted_daily) - self._deterministic("predicted", predicted) + predicted = self._compute_predicted( + infections, first_day_dow, right_truncation_offset + ) if self.reporting_schedule == "regular": - valid_pred = ~jnp.isnan(predicted) - if obs is not None: - valid_obs = ~jnp.isnan(obs) - mask = valid_pred & valid_obs - else: - mask = valid_pred - - # JAX evaluates log_prob for all elements even when mask excludes - # them from the likelihood sum. Replace NaN with safe values to - # avoid NaN propagation; mask=False ensures they do not affect - # inference. - safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) - safe_obs = None - if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) - - observed = self.noise.sample( - name=self._sample_site_name("obs"), - predicted=safe_predicted, - obs=safe_obs, - mask=mask, - ) + observed = self._score_masked(predicted, obs) else: if period_end_times is None: raise ValueError( f"Observation '{self.name}': period_end_times is " f"required when reporting_schedule == 'irregular'" ) - P = self.aggregation_period - offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) - period_idx = (jnp.asarray(period_end_times) - offset - (P - 1)) // P - predicted_obs = predicted[period_idx] - + period_idx = self._period_indices(period_end_times, first_day_dow) observed = self.noise.sample( name=self._sample_site_name("obs"), - predicted=predicted_obs, + predicted=predicted[period_idx], obs=obs, ) @@ -665,7 +801,7 @@ def validate_data( first_day_dow Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when - ``aggregation_period > 1`` so weekly observation periods can be + ``aggregation == "weekly"`` so weekly observation periods can be aligned to the shared daily model axis. subpop_indices Subpopulation indices (0-indexed). For @@ -682,23 +818,15 @@ def validate_data( ValueError If any index array is out of bounds, any shape check fails, or ``first_day_dow`` is missing when - ``aggregation_period > 1``. + ``aggregation == "weekly"``. """ - P = self.aggregation_period - if subpop_indices is not None: self._validate_subpop_indices(subpop_indices, n_subpops) if self.reporting_schedule == "regular": if obs is None: return - if P > 1 and first_day_dow is None: - raise ValueError( - f"Observation '{self.name}': first_day_dow is required " - f"when aggregation_period == {P}" - ) - offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) - n_periods = n_total if P == 1 else (n_total - offset) // P + n_periods = self._n_periods(n_total, first_day_dow) obs = jnp.asarray(obs) if obs.ndim != 2: raise ValueError( @@ -723,13 +851,10 @@ def validate_data( if period_end_times is None: return - if P > 1 and first_day_dow is None: - raise ValueError( - f"Observation '{self.name}': first_day_dow is required " - f"when aggregation_period == {P}" - ) - offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) - self._validate_period_end_times(period_end_times, n_total, offset, P) + offset = self._compute_period_offset(first_day_dow, self.week) + self._validate_period_end_times( + period_end_times, n_total, offset, self.aggregation_period + ) if obs is not None: self._validate_shapes_match( obs, period_end_times, "obs", "period_end_times" @@ -755,7 +880,7 @@ def sample( Sample subpopulation-level counts. Daily transforms (right-truncation, day-of-week) run on the - daily axis. When ``aggregation_period > 1`` the daily + daily axis. When ``aggregation == "weekly"`` the daily predictions are summed onto the reporting-period grid before the noise model. Likelihood path depends on ``reporting_schedule``: ``"regular"`` selects the observed @@ -788,7 +913,7 @@ def sample( Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when ``day_of_week_rv`` was set at construction or when - ``aggregation_period > 1``. This aligns observation-level + ``aggregation == "weekly"``. This aligns observation-level day-of-week effects or weekly aggregation to the shared daily model axis. period_end_times @@ -808,69 +933,27 @@ def sample( Named tuple with ``observed`` (sampled/conditioned counts) and ``predicted`` (predictions on the reporting-period grid, shape ``(n_periods, n_subpops)``; equal to daily - predictions when ``aggregation_period == 1``). + predictions when ``aggregation == "daily"``). """ if subpop_indices is None: raise ValueError(f"Observation '{self.name}': subpop_indices is required.") - predicted_daily = self._predicted_obs(infections) - if self.day_of_week_rv is not None: - if first_day_dow is None: - raise ValueError( - "first_day_dow is required when day_of_week_rv is set." - ) - predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) - if self.right_truncation_rv is not None and right_truncation_offset is not None: - predicted_daily = self._apply_right_truncation( - predicted_daily, right_truncation_offset - ) - - predicted = self._aggregate(predicted_daily, first_day_dow) - if self.aggregation_period > 1: - self._deterministic("predicted_daily", predicted_daily) - self._deterministic("predicted", predicted) + predicted = self._compute_predicted( + infections, first_day_dow, right_truncation_offset + ) if self.reporting_schedule == "regular": - predicted_selected = predicted[:, subpop_indices] - - valid_pred = ~jnp.isnan(predicted_selected) - if obs is not None: - valid_obs = ~jnp.isnan(obs) - mask = valid_pred & valid_obs - else: - mask = valid_pred - - # JAX evaluates log_prob for all elements even when mask excludes - # them from the likelihood sum. Replace NaN with safe values to - # avoid NaN propagation; mask=False ensures they do not affect - # inference. - safe_predicted = jnp.where( - jnp.isnan(predicted_selected), 1.0, predicted_selected - ) - safe_obs = None - if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) - - observed = self.noise.sample( - name=self._sample_site_name("obs"), - predicted=safe_predicted, - obs=safe_obs, - mask=mask, - ) + observed = self._score_masked(predicted[:, subpop_indices], obs) else: if period_end_times is None: raise ValueError( f"Observation '{self.name}': period_end_times is " f"required when reporting_schedule == 'irregular'" ) - P = self.aggregation_period - offset = self._compute_period_offset(first_day_dow, P, self.period_end_dow) - period_idx = (jnp.asarray(period_end_times) - offset - (P - 1)) // P - predicted_obs = predicted[period_idx, subpop_indices] - + period_idx = self._period_indices(period_end_times, first_day_dow) observed = self.noise.sample( name=self._sample_site_name("obs"), - predicted=predicted_obs, + predicted=predicted[period_idx, subpop_indices], obs=obs, ) diff --git a/pyrenew/time.py b/pyrenew/time.py index 54021917..80ee0627 100644 --- a/pyrenew/time.py +++ b/pyrenew/time.py @@ -7,31 +7,28 @@ Calendar concepts used elsewhere in pyrenew ------------------------------------------- -Three day-of-week values appear across the codebase. They name -distinct things; do not conflate them. +Two calendar concepts appear across the codebase. They name distinct +things; do not conflate them. - ``first_day_dow`` — day-of-week of element 0 of a model's shared daily time axis. Sample-time data fact (depends on the dataset being fit). Threaded through ``MultiSignalModel.sample`` and the ``TemporalProcess.sample`` protocol. -- ``period_end_dow`` — day-of-week on which each weekly observation - period ends. Construction-time modeling choice on - ``CountObservation``. MMWR convention: 5 (Saturday). ISO - convention: 6 (Sunday). +- :class:`WeekCycle` — a 7-day calendar cycle identified by its + ``start_dow``. Construction-time modeling choice, shared by every + model component that must agree on the same calendar week + (weekly :class:`CountObservation` and calendar-week-aligned + :class:`StepwiseTemporalProcess`). The module-level constants + :data:`MMWR_WEEK` (Sun-Sat) and :data:`ISO_WEEK` (Mon-Sun) cover + the common conventions; custom cycles use ``WeekCycle(start_dow=k)``. -- ``week_start_dow`` — day-of-week on which weekly Rt blocks begin. - Construction-time modeling choice on ``StepwiseTemporalProcess`` - with ``alignment="calendar_week"``. Also the parameter name used by - ``daily_to_weekly`` / ``weekly_to_daily`` in this module. +``PyrenewBuilder._validate_coherence`` enforces that all weekly +observations share one ``WeekCycle`` and that any calendar-week-aligned +temporal process uses the same cycle. -The two anchors are complements: a weekly observation period that ends -on day X is contained in the calendar week that starts on day -``(X + 1) % 7``. ``PyrenewBuilder._validate_coherence`` enforces -``period_end_dow == (week_start_dow + 6) % 7`` when both are present. - -Worked example: ``first_day_dow=3`` (Thursday), ``week_start_dow=6`` -(Sunday), 17-day axis:: +Worked example: ``first_day_dow=3`` (Thursday), ``week=MMWR_WEEK`` +(Sunday start), 17-day axis:: model index: 0 1 2 | 3 4 5 6 7 8 9 | 10 11 12 13 14 15 16 weekday: Th Fr Sa | Su Mo Tu We Th Fr Sa | Su Mo Tu We Th Fr Sa @@ -41,14 +38,15 @@ each subsequent block of 7 days falls in one full Sunday-Saturday week and shares one Rt sample. -The conversion between the model-axis fact and these calendar anchors -happens at call sites (e.g., ``daily_to_weekly(daily_predicted, -input_data_first_dow=first_day_dow, week_start_dow=...)``); functions -in this module use locally scoped role names like -``input_data_first_dow`` and ``output_data_first_dow``. +The low-level functions :func:`daily_to_weekly` and +:func:`weekly_to_daily` take integer day-of-week parameters +(``input_data_first_dow``, ``output_data_first_dow``, +``week_start_dow``) directly so that callers outside pyrenew can use +them without the :class:`WeekCycle` abstraction. """ import datetime as dt +from dataclasses import dataclass import jax.numpy as jnp import numpy as np @@ -96,6 +94,46 @@ def validate_dow(day_of_week: int, variable_name: str) -> None: return None +@dataclass(frozen=True) +class WeekCycle: + """ + A 7-day calendar cycle identified by its start day. + + Value object that names a single weekly convention (MMWR, + ISO, or custom) and is shared by any model component that + must agree on it. Equality is structural, so a builder-level + coherence check reduces to ``a == b``. + + Attributes + ---------- + start_dow + Day-of-week on which the cycle begins + (0=Monday, 6=Sunday, ISO convention). + """ + + start_dow: int + + def __post_init__(self) -> None: + """Validate ``start_dow`` via :func:`validate_dow`.""" + validate_dow(self.start_dow, "start_dow") + + @property + def end_dow(self) -> int: + """ + Day-of-week on which the cycle ends. + + Returns + ------- + int + ``(start_dow + 6) % 7``. + """ + return (self.start_dow + 6) % 7 + + +MMWR_WEEK: WeekCycle = WeekCycle(start_dow=6) +ISO_WEEK: WeekCycle = WeekCycle(start_dow=0) + + def get_sequential_day_of_week_indices( first_day_dow: int, n_timepoints: int ) -> jnp.ndarray: diff --git a/test/conftest.py b/test/conftest.py index d41bf2c9..197f6d7d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -38,6 +38,7 @@ SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable +from pyrenew.time import MMWR_WEEK # ============================================================================= # PMF Fixtures @@ -281,8 +282,7 @@ def weekly_regular_counts(simple_delay_pmf): """ PopulationCounts with weekly aggregation and regular (dense) reporting. - Reporting periods end on Saturdays (``period_end_dow=5``), matching the - MMWR epiweek convention. + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). Returns ------- @@ -294,9 +294,9 @@ def weekly_regular_counts(simple_delay_pmf): ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=PoissonNoise(), - aggregation_period=7, + aggregation="weekly", reporting_schedule="regular", - period_end_dow=5, + week=MMWR_WEEK, ) @@ -305,8 +305,7 @@ def weekly_irregular_counts(simple_delay_pmf): """ PopulationCounts with weekly aggregation and irregular (sparse) reporting. - Reporting periods end on Saturdays (``period_end_dow=5``), matching the - MMWR epiweek convention. + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). Returns ------- @@ -318,9 +317,9 @@ def weekly_irregular_counts(simple_delay_pmf): ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=PoissonNoise(), - aggregation_period=7, + aggregation="weekly", reporting_schedule="irregular", - period_end_dow=5, + week=MMWR_WEEK, ) @@ -348,8 +347,7 @@ def weekly_regular_subpop_counts(simple_delay_pmf): """ SubpopulationCounts with weekly aggregation and regular (dense) reporting. - Reporting periods end on Saturdays (``period_end_dow=5``), matching the - MMWR epiweek convention. + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). Returns ------- @@ -361,9 +359,9 @@ def weekly_regular_subpop_counts(simple_delay_pmf): ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=PoissonNoise(), - aggregation_period=7, + aggregation="weekly", reporting_schedule="regular", - period_end_dow=5, + week=MMWR_WEEK, ) @@ -372,8 +370,7 @@ def weekly_irregular_subpop_counts(simple_delay_pmf): """ SubpopulationCounts with weekly aggregation and irregular (sparse) reporting. - Reporting periods end on Saturdays (``period_end_dow=5``), matching the - MMWR epiweek convention. + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). Returns ------- @@ -385,9 +382,9 @@ def weekly_irregular_subpop_counts(simple_delay_pmf): ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), noise=PoissonNoise(), - aggregation_period=7, + aggregation="weekly", reporting_schedule="irregular", - period_end_dow=5, + week=MMWR_WEEK, ) diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 202ca128..21049688 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -26,6 +26,7 @@ from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import NegativeBinomialNoise, PopulationCounts from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import MMWR_WEEK @pytest.fixture(scope="module") @@ -247,7 +248,7 @@ def he_weekly_rt_model( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), ) @@ -258,9 +259,9 @@ def he_weekly_rt_model( noise=NegativeBinomialNoise( DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) ), - aggregation_period=7, + aggregation="weekly", reporting_schedule="regular", - period_end_dow=5, + week=MMWR_WEEK, ) builder.add_observation(hospital_obs) @@ -287,10 +288,11 @@ def he_weekly_model( """ Build a PopulationInfections model with WEEKLY hospital + DAILY ED observations. - The hospital observation is aggregated to MMWR epiweeks (Sunday-Saturday, - anchored by ``period_end_dow=5``); the ED observation stays daily with a - day-of-week effect. R(t) is parametrized at the finest observation - cadence (daily) per the coherence rules for mixed-cadence models. + The hospital observation is aggregated to MMWR epiweeks + (Sunday-Saturday, via ``MMWR_WEEK``); the ED observation stays + daily with a day-of-week effect. R(t) is parametrized at the + finest observation cadence (daily) per the coherence rules for + mixed-cadence models. Parameters ---------- @@ -326,9 +328,9 @@ def he_weekly_model( noise=NegativeBinomialNoise( DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) ), - aggregation_period=7, + aggregation="weekly", reporting_schedule="regular", - period_end_dow=5, + week=MMWR_WEEK, ) builder.add_observation(hospital_obs) diff --git a/test/integration/test_population_infections_he_weekly.py b/test/integration/test_population_infections_he_weekly.py index 9872059d..01888bd7 100644 --- a/test/integration/test_population_infections_he_weekly.py +++ b/test/integration/test_population_infections_he_weekly.py @@ -18,6 +18,7 @@ import pytest from pyrenew.model import MultiSignalModel +from pyrenew.time import MMWR_WEEK pytestmark = pytest.mark.integration @@ -61,7 +62,7 @@ def _build_hospital_obs_on_period_grid( hosp = model.observations["hospital"] n_init = model.latent.n_initialization_points n_total = n_init + N_DAYS_FIT - offset = (hosp.period_end_dow + 1 - first_day_dow) % hosp.aggregation_period + offset = hosp._compute_period_offset(first_day_dow, hosp.week) n_periods = (n_total - offset) // hosp.aggregation_period n_pre = n_periods - len(weekly_values) return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) @@ -108,9 +109,9 @@ def test_hospital_is_weekly_regular( Built model. """ h = he_weekly_model.observations["hospital"] - assert h.aggregation_period == 7 + assert h.aggregation == "weekly" assert h.reporting_schedule == "regular" - assert h.period_end_dow == 5 + assert h.week == MMWR_WEEK def test_ed_stays_daily( self, @@ -125,7 +126,7 @@ def test_ed_stays_daily( Built model. """ ed = he_weekly_model.observations["ed"] - assert ed.aggregation_period == 1 + assert ed.aggregation == "daily" assert ed.day_of_week_rv is not None def test_weekly_obs_alignment( @@ -154,7 +155,7 @@ def test_weekly_obs_alignment( n_init = he_weekly_model.latent.n_initialization_points n_total = n_init + N_DAYS_FIT hosp = he_weekly_model.observations["hospital"] - offset = (hosp.period_end_dow + 1 - first_day_dow) % hosp.aggregation_period + offset = hosp._compute_period_offset(first_day_dow, hosp.week) n_periods = (n_total - offset) // hosp.aggregation_period assert hosp_obs.shape == (n_periods,) diff --git a/test/integration/test_population_infections_he_weekly_rt.py b/test/integration/test_population_infections_he_weekly_rt.py new file mode 100644 index 00000000..0cb24224 --- /dev/null +++ b/test/integration/test_population_infections_he_weekly_rt.py @@ -0,0 +1,381 @@ +""" +Integration test: PopulationInfections H+E model with WEEKLY R(t). + +Mirrors ``test_population_infections_he_weekly`` but parameterizes R(t) +weekly via ``StepwiseTemporalProcess(step_size=7, alignment="calendar_week")`` +and broadcasts to the daily renewal axis. This is the production +pyrenew-hew configuration: weekly hospital admissions + daily ED visits + +weekly calendar-aligned R(t). +""" + +from __future__ import annotations + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpy as np +import numpyro +import polars as pl +import pytest + +from pyrenew.model import MultiSignalModel + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 +# Day 0 of the synthetic data is 2023-11-05, a Sunday (ISO dow = 6). +OBS_START_DOW = 6 +WEEK_START_DOW = 6 + + +def _build_hospital_obs_on_period_grid( + model: MultiSignalModel, + weekly_values: jnp.ndarray, + first_day_dow: int, +) -> jnp.ndarray: + """ + Build a dense weekly-observation array on the model's period grid. + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + weekly_values : jnp.ndarray + Observed weekly hospital admissions, one per MMWR epiweek. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + jnp.ndarray + Dense array of shape ``(n_periods,)`` with NaN for unobserved + periods and observed counts for periods covered by + ``weekly_values``. + """ + hosp = model.observations["hospital"] + n_init = model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + offset = hosp._compute_period_offset(first_day_dow, hosp.week) + n_periods = (n_total - offset) // hosp.aggregation_period + n_pre = n_periods - len(weekly_values) + return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) + + +def _expected_n_coarse(model: MultiSignalModel, first_day_dow: int) -> int: + """ + Expected number of coarse R(t) samples for calendar-week alignment. + + Mirrors ``StepwiseTemporalProcess._resolve_n_coarse`` for ``step_size=7`` + and ``alignment="calendar_week"``. + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + int + Number of weekly Rt samples covering the daily model axis. + """ + n_total = model.latent.n_initialization_points + N_DAYS_FIT + trim = (first_day_dow - WEEK_START_DOW) % 7 + return (n_total + trim + 6) // 7 + + +class TestPriorPredictiveStructure: + """Verify the weekly-Rt graph records a coarse trajectory at the right shape.""" + + def test_coarse_rt_recorded( + self, + he_weekly_rt_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> None: + """ + Single prior-predictive sample exposes ``log_rt_single_coarse`` at the + expected coarse length and a daily-length broadcast Rt. + + Parameters + ---------- + he_weekly_rt_model : MultiSignalModel + Built model with calendar-aligned weekly Rt. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + """ + first_day_dow = he_weekly_rt_model.compute_first_day_dow(OBS_START_DOW) + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_rt_model, weekly_values, first_day_dow + ) + ed_obs = he_weekly_rt_model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + population_size = float(weekly_hosp["pop"][0]) + + with numpyro.handlers.seed(rng_seed=0): + with numpyro.handlers.trace() as trace: + he_weekly_rt_model.sample( + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + first_day_dow=first_day_dow, + hospital={"obs": hosp_obs, "first_day_dow": first_day_dow}, + ed={"obs": ed_obs, "first_day_dow": first_day_dow}, + ) + + n_total = he_weekly_rt_model.latent.n_initialization_points + N_DAYS_FIT + n_coarse = _expected_n_coarse(he_weekly_rt_model, first_day_dow) + + coarse = trace["log_rt_single_coarse"]["value"] + daily = trace["PopulationInfections::log_rt_single"]["value"] + + assert coarse.shape == (n_coarse, 1) + assert daily.shape == (n_total, 1) + assert n_coarse < n_total + + # Each block of 7 daily values past the leading partial week should be + # constant (the calendar-week broadcast invariant). + partial_len = (WEEK_START_DOW - first_day_dow) % 7 + if partial_len > 0: + assert jnp.allclose(daily[:partial_len], daily[0]) + first_full = partial_len + assert jnp.allclose(daily[first_full : first_full + 7], daily[first_full]) + + +class TestModelFit: + """Fit the weekly-Rt H+E model and check posterior recovery.""" + + @pytest.fixture(scope="class") + def fitted_model( + self, + he_weekly_rt_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """ + Fit the weekly-Rt H+E model to synthetic data via MCMC. + + Parameters + ---------- + he_weekly_rt_model : MultiSignalModel + Built model with calendar-aligned weekly Rt. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + + Returns + ------- + MultiSignalModel + Model with MCMC results attached. + """ + first_day_dow = he_weekly_rt_model.compute_first_day_dow(OBS_START_DOW) + + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_rt_model, weekly_values, first_day_dow + ) + + ed_obs = he_weekly_rt_model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + population_size = float(weekly_hosp["pop"][0]) + + he_weekly_rt_model.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + first_day_dow=first_day_dow, + hospital={"obs": hosp_obs, "first_day_dow": first_day_dow}, + ed={"obs": ed_obs, "first_day_dow": first_day_dow}, + ) + + samples = he_weekly_rt_model.mcmc.get_samples() + jax.block_until_ready(samples) + return he_weekly_rt_model + + @pytest.fixture(scope="class") + def posterior_dt( + self, + fitted_model: MultiSignalModel, + ): + """ + Convert MCMC samples to an ArviZ DataTree, trimming the init period. + + Parameters + ---------- + fitted_model : MultiSignalModel + Model with MCMC results. + + Returns + ------- + xarray.DataTree + ArviZ DataTree with posterior group. + """ + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "log_rt_single_coarse": ["coarse_week", "dummy"], + "hospital_predicted_daily": ["time"], + "hospital_predicted": ["week"], + "ed_predicted": ["time"], + }, + ) + + def trim_init(ds): + """ + Trim the initialization period from the ``time`` dimension only. + + Parameters + ---------- + ds + Dataset to trim. + + Returns + ------- + xarray.Dataset + Dataset with ``time`` sliced to ``[n_init:]``; other dims + pass through unchanged. + """ + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """ + Check that core parameters have acceptable Rhat and ESS. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_coarse_rt_posterior_shape( + self, + fitted_model: MultiSignalModel, + posterior_dt, + ) -> None: + """ + Check the coarse Rt site lives on the weekly cadence in the posterior. + + Parameters + ---------- + fitted_model : MultiSignalModel + Fitted model. + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + first_day_dow = fitted_model.compute_first_day_dow(OBS_START_DOW) + n_coarse = _expected_n_coarse(fitted_model, first_day_dow) + + coarse = posterior_dt.posterior["log_rt_single_coarse"] + assert coarse.sizes["coarse_week"] == n_coarse + + def test_rt_posterior_covers_truth( + self, + posterior_dt, + daily_infections: pl.DataFrame, + ) -> None: + """ + Check that the 90% credible interval for R(t) covers the true value + for at least 80% of time points. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + daily_infections : pl.DataFrame + True R(t) trajectory. + """ + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] + rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values + rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values + + true_rt = daily_infections["true_rt"].to_numpy() + + if rt_q05.ndim > 1: + rt_q05 = rt_q05.squeeze() + rt_q95 = rt_q95.squeeze() + + n_compare = min(len(true_rt), len(rt_q05)) + covered = (true_rt[:n_compare] >= rt_q05[:n_compare]) & ( + true_rt[:n_compare] <= rt_q95[:n_compare] + ) + coverage = float(np.mean(covered)) + assert coverage >= 0.80, ( + f"R(t) 90% CI coverage was {coverage:.1%}, expected >= 80%" + ) + + def test_ascertainment_rates_recover_order_of_magnitude( + self, + posterior_dt, + true_params: dict, + ) -> None: + """ + Check that posterior median IHR and IEDR are within a factor + of 5 of the true values. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + true_params : dict + Ground-truth parameter dictionary. + """ + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + + ihr_median = float( + posterior_dt.posterior["ihr"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior_dt.posterior["iedr"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5, ( + f"IHR median {ihr_median:.4f} not within 5x of true {true_ihr}" + ) + assert true_iedr / 5 <= iedr_median <= true_iedr * 5, ( + f"IEDR median {iedr_median:.4f} not within 5x of true {true_iedr}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index ad1cf940..62f6e6f0 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -15,6 +15,7 @@ SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import MMWR_WEEK from test.test_helpers import create_mock_infections @@ -890,41 +891,46 @@ def _make(self, simple_delay_pmf, **kwargs): ) def test_default_construction_is_daily_regular(self, simple_delay_pmf): - """Default constructor yields aggregation_period=1, reporting_schedule='regular'.""" + """Default constructor yields aggregation='daily', reporting_schedule='regular'.""" process = self._make(simple_delay_pmf) - assert process.aggregation_period == 1 + assert process.aggregation == "daily" assert process.reporting_schedule == "regular" - assert process.period_end_dow is None + assert process.week is None - def test_weekly_requires_period_end_dow(self, simple_delay_pmf): - """aggregation_period=7 without period_end_dow must raise.""" - with pytest.raises(ValueError, match="period_end_dow is required"): - self._make(simple_delay_pmf, aggregation_period=7) + def test_weekly_requires_week(self, simple_delay_pmf): + """aggregation='weekly' without week must raise.""" + with pytest.raises(ValueError, match="week is required"): + self._make(simple_delay_pmf, aggregation="weekly") - def test_weekly_with_saturday_anchor_constructs(self, simple_delay_pmf): - """aggregation_period=7, period_end_dow=5 (Saturday) is valid.""" - process = self._make(simple_delay_pmf, aggregation_period=7, period_end_dow=5) - assert process.aggregation_period == 7 - assert process.period_end_dow == 5 + def test_weekly_with_mmwr_anchor_constructs(self, simple_delay_pmf): + """aggregation='weekly' with MMWR_WEEK is valid.""" + process = self._make(simple_delay_pmf, aggregation="weekly", week=MMWR_WEEK) + assert process.aggregation == "weekly" + assert process.week == MMWR_WEEK def test_dow_effect_with_weekly_aggregation_raises(self, simple_delay_pmf): - """day_of_week_rv cannot be combined with aggregation_period > 1.""" + """day_of_week_rv cannot be combined with aggregation='weekly'.""" with pytest.raises(ValueError, match="day_of_week_rv cannot be combined"): self._make( simple_delay_pmf, - aggregation_period=7, - period_end_dow=5, + aggregation="weekly", + week=MMWR_WEEK, day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), ) def test_dow_effect_with_daily_aggregation_allowed(self, simple_delay_pmf): - """day_of_week_rv remains valid for aggregation_period=1.""" + """day_of_week_rv remains valid for aggregation='daily'.""" process = self._make( simple_delay_pmf, day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), ) assert process.day_of_week_rv is not None + def test_unknown_aggregation_raises(self, simple_delay_pmf): + """aggregation must be 'daily' or 'weekly'.""" + with pytest.raises(ValueError, match="aggregation must be one of"): + self._make(simple_delay_pmf, aggregation="monthly") + def test_unknown_reporting_schedule_raises(self, simple_delay_pmf): """reporting_schedule must be 'regular' or 'irregular'.""" with pytest.raises(ValueError, match="reporting_schedule must be one of"): @@ -1134,10 +1140,10 @@ def test_daily_irregular_period_indexing(self, daily_irregular_counts): assert result.observed.shape == (3,) def test_aggregate_helper_missing_first_day_dow_raises(self, weekly_regular_counts): - """_aggregate raises when aggregation_period > 1 and first_day_dow is None.""" + """_aggregate raises when aggregation == 'weekly' and first_day_dow is None.""" predicted_daily = jnp.ones(28) with pytest.raises( - ValueError, match="first_day_dow is required when aggregation_period > 1" + ValueError, match="first_day_dow is required when aggregation == 'weekly'" ): weekly_regular_counts._aggregate(predicted_daily, first_day_dow=None) diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py index cbd47906..d08d7831 100644 --- a/test/test_observation_validation.py +++ b/test/test_observation_validation.py @@ -19,7 +19,7 @@ SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable -from pyrenew.time import daily_to_weekly +from pyrenew.time import MMWR_WEEK, WeekCycle, daily_to_weekly # --------------------------------------------------------------------------- # Helpers – minimal concrete subclass of MeasurementObservation for testing @@ -414,13 +414,13 @@ def test_2d_obs_weekly_raises(self): ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.3, 0.5, 0.2])), noise=PoissonNoise(), - aggregation_period=7, + aggregation="weekly", reporting_schedule="regular", - period_end_dow=5, + week=MMWR_WEEK, ) n_total = 35 first_day_dow = 6 - offset = (proc.period_end_dow + 1 - first_day_dow) % proc.aggregation_period + offset = proc._compute_period_offset(first_day_dow, proc.week) n_periods = (n_total - offset) // proc.aggregation_period obs = jnp.ones((n_periods, 1)) with pytest.raises(ValueError, match="obs must be 1D"): @@ -654,61 +654,36 @@ def test_extra_kwargs_ignored(self, measurements_proc): # =================================================================== -# _validate_aggregation_params +# _validate_week # =================================================================== -class TestValidateAggregationParams: - """Tests for BaseObservationProcess._validate_aggregation_params.""" +class TestValidateWeek: + """Tests for BaseObservationProcess._validate_week.""" - def test_p1_no_anchor_passes(self, counts_proc): - """P=1 with period_end_dow=None should not raise.""" - counts_proc._validate_aggregation_params(1, None) + def test_daily_with_no_week_passes(self, counts_proc): + """aggregation='daily' with week=None should not raise.""" + counts_proc._validate_week("daily", None) - def test_p1_anchor_ignored(self, counts_proc): - """P=1 with any period_end_dow value should not raise.""" - counts_proc._validate_aggregation_params(1, 5) - counts_proc._validate_aggregation_params(1, 99) + def test_daily_ignores_week(self, counts_proc): + """aggregation='daily' ignores any supplied WeekCycle.""" + counts_proc._validate_week("daily", MMWR_WEEK) - def test_p7_valid_anchor_passes(self, counts_proc): - """P=7 with period_end_dow in {0, ..., 6} should not raise.""" - for dow in range(7): - counts_proc._validate_aggregation_params(7, dow) + def test_weekly_with_week_passes(self, counts_proc): + """aggregation='weekly' with any WeekCycle should not raise.""" + counts_proc._validate_week("weekly", MMWR_WEEK) + counts_proc._validate_week("weekly", WeekCycle(start_dow=0)) - def test_p7_missing_anchor_raises(self, counts_proc): - """P=7 with period_end_dow=None should raise.""" - with pytest.raises(ValueError, match="period_end_dow is required"): - counts_proc._validate_aggregation_params(7, None) + def test_weekly_missing_week_raises(self, counts_proc): + """aggregation='weekly' with week=None should raise.""" + with pytest.raises(ValueError, match="week is required"): + counts_proc._validate_week("weekly", None) - def test_p7_negative_anchor_raises(self, counts_proc): - """P=7 with period_end_dow<0 should raise.""" - with pytest.raises(ValueError, match="integer in"): - counts_proc._validate_aggregation_params(7, -1) - - def test_p7_too_large_anchor_raises(self, counts_proc): - """P=7 with period_end_dow>6 should raise.""" - with pytest.raises(ValueError, match="integer in"): - counts_proc._validate_aggregation_params(7, 7) - - def test_zero_period_raises(self, counts_proc): - """aggregation_period=0 should raise.""" - with pytest.raises(ValueError, match="positive integer"): - counts_proc._validate_aggregation_params(0, None) - - def test_negative_period_raises(self, counts_proc): - """Negative aggregation_period should raise.""" - with pytest.raises(ValueError, match="positive integer"): - counts_proc._validate_aggregation_params(-3, None) - - def test_unsupported_period_raises(self, counts_proc): - """aggregation_period not in {1, 7} should raise.""" - with pytest.raises(ValueError, match=r"one of \{1, 7\}"): - counts_proc._validate_aggregation_params(14, 5) - - def test_float_period_raises(self, counts_proc): - """Non-integer aggregation_period should raise.""" - with pytest.raises(ValueError, match="positive integer"): - counts_proc._validate_aggregation_params(7.0, 5) + @pytest.mark.parametrize("bad", ["monthly", "biweekly", "Weekly", ""]) + def test_unknown_aggregation_raises(self, counts_proc, bad): + """Unrecognized aggregation strings should raise.""" + with pytest.raises(ValueError, match="aggregation must be one of"): + counts_proc._validate_week(bad, MMWR_WEEK) # =================================================================== @@ -719,64 +694,57 @@ def test_float_period_raises(self, counts_proc): class TestComputePeriodOffset: """Tests for BaseObservationProcess._compute_period_offset.""" - def test_p1_returns_zero(self, counts_proc): - """P=1 always returns 0 regardless of dow arguments.""" - assert counts_proc._compute_period_offset(None, 1, None) == 0 - assert counts_proc._compute_period_offset(0, 1, 5) == 0 - assert counts_proc._compute_period_offset(6, 1, None) == 0 + def test_daily_returns_zero(self, counts_proc): + """week=None always returns 0 regardless of first_day_dow.""" + assert counts_proc._compute_period_offset(None, None) == 0 + assert counts_proc._compute_period_offset(0, None) == 0 + assert counts_proc._compute_period_offset(6, None) == 0 - def test_p7_mmwr_aligned_start(self, counts_proc): - """Daily axis starting Sunday with Saturday end => offset 0.""" - assert counts_proc._compute_period_offset(6, 7, 5) == 0 + def test_mmwr_aligned_start(self, counts_proc): + """Daily axis starting Sunday under MMWR_WEEK => offset 0.""" + assert counts_proc._compute_period_offset(6, MMWR_WEEK) == 0 - def test_p7_monday_start_saturday_end(self, counts_proc): - """Daily axis starting Monday with Saturday end => offset 6.""" - assert counts_proc._compute_period_offset(0, 7, 5) == 6 + def test_monday_start_mmwr(self, counts_proc): + """Daily axis starting Monday under MMWR_WEEK => offset 6.""" + assert counts_proc._compute_period_offset(0, MMWR_WEEK) == 6 - def test_p7_saturday_start_saturday_end(self, counts_proc): - """Daily axis starting Saturday with Saturday end => offset 1.""" - assert counts_proc._compute_period_offset(5, 7, 5) == 1 + def test_saturday_start_mmwr(self, counts_proc): + """Daily axis starting Saturday under MMWR_WEEK => offset 1.""" + assert counts_proc._compute_period_offset(5, MMWR_WEEK) == 1 - def test_p7_iso_week_alignment(self, counts_proc): - """Daily axis starting Thursday with Sunday end (ISO) => offset 4.""" - assert counts_proc._compute_period_offset(3, 7, 6) == 4 + def test_iso_week_alignment(self, counts_proc): + """Daily axis starting Thursday under ISO week (Mon start) => offset 4.""" + assert counts_proc._compute_period_offset(3, WeekCycle(start_dow=0)) == 4 - def test_p7_offset_always_in_range(self, counts_proc): - """P=7 offset is always in [0, 7) for any valid dow combination.""" + def test_offset_always_in_range(self, counts_proc): + """Offset is always in [0, 7) for any valid (first_day_dow, week).""" for first in range(7): - for end in range(7): - offset = counts_proc._compute_period_offset(first, 7, end) + for start in range(7): + offset = counts_proc._compute_period_offset( + first, WeekCycle(start_dow=start) + ) assert 0 <= offset < 7 - def test_p7_missing_first_day_dow_raises(self, counts_proc): - """P=7 with first_day_dow=None should raise.""" - with pytest.raises(ValueError, match="both required"): - counts_proc._compute_period_offset(None, 7, 5) - - def test_p7_missing_period_end_dow_raises(self, counts_proc): - """P=7 with period_end_dow=None should raise.""" - with pytest.raises(ValueError, match="both required"): - counts_proc._compute_period_offset(0, 7, None) - - def test_unsupported_period_raises(self, counts_proc): - """P not in {1, 7} should raise.""" - with pytest.raises(ValueError, match=r"one of \{1, 7\}"): - counts_proc._compute_period_offset(0, 14, 5) + def test_missing_first_day_dow_raises(self, counts_proc): + """week set with first_day_dow=None should raise.""" + with pytest.raises(ValueError, match="first_day_dow is required"): + counts_proc._compute_period_offset(None, MMWR_WEEK) def test_offset_agrees_with_daily_to_weekly(self, counts_proc): """ Offset from _compute_period_offset selects the same leading days that daily_to_weekly trims internally for the first - complete period, for every (first_day_dow, period_end_dow). + complete period, for every (first_day_dow, week) pair. """ daily = jnp.arange(21.0) for first in range(7): - for end in range(7): - offset = counts_proc._compute_period_offset(first, 7, end) + for start in range(7): + week = WeekCycle(start_dow=start) + offset = counts_proc._compute_period_offset(first, week) weekly = daily_to_weekly( daily, input_data_first_dow=first, - week_start_dow=(end + 1) % 7, + week_start_dow=start, ) expected_first_week = float(jnp.sum(daily[offset : offset + 7])) assert float(weekly[0]) == expected_first_week diff --git a/test/test_population_infections.py b/test/test_population_infections.py index 2c9c9683..cbe6febc 100644 --- a/test/test_population_infections.py +++ b/test/test_population_infections.py @@ -9,6 +9,7 @@ from pyrenew.deterministic import DeterministicVariable from pyrenew.latent import AR1, RandomWalk, StepwiseTemporalProcess from pyrenew.latent.population_infections import PopulationInfections +from pyrenew.time import MMWR_WEEK class TestPopulationInfectionsSample: @@ -136,7 +137,7 @@ def test_first_day_dow_reaches_calendar_aligned_rt_process(self, gen_int_rv): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), n_initialization_points=7, ) diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index 12c1ef20..a029c49a 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -21,6 +21,7 @@ PopulationCounts, SubpopulationCounts, ) +from pyrenew.time import MMWR_WEEK, WeekCycle # Standard population structure for tests (3 subpopulations) SUBPOP_FRACTIONS = jnp.array([0.3, 0.25, 0.45]) @@ -287,7 +288,7 @@ def test_first_day_dow_reaches_calendar_aligned_latent_process(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), n_initialization_points=3, ) @@ -321,7 +322,7 @@ def test_missing_first_day_dow_for_calendar_aligned_latent_process_raises(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), n_initialization_points=3, ) @@ -349,7 +350,7 @@ def test_builder_mixed_cadence_weekly_rt_samples(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), observations=[_weekly_hosp_counts(), _daily_ed_counts()], ) @@ -587,23 +588,23 @@ def _coherence_builder( return builder -def _weekly_hosp_counts(name="hospital", period_end_dow=5): +def _weekly_hosp_counts(name="hospital", week=MMWR_WEEK): """ Build a weekly-aggregated PopulationCounts observation with PoissonNoise. Returns ------- PopulationCounts - Weekly-regular observation anchored to the specified period end dow. + Weekly-regular observation anchored to the specified ``WeekCycle``. """ return PopulationCounts( name=name, ascertainment_rate_rv=DeterministicVariable(f"{name}_ihr", 0.01), delay_distribution_rv=DeterministicPMF(f"{name}_delay", jnp.array([1.0])), noise=PoissonNoise(), - aggregation_period=7, + aggregation="weekly", reporting_schedule="regular", - period_end_dow=period_end_dow, + week=week, ) @@ -667,25 +668,25 @@ def test_weekly_rt_with_daily_observation_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_mismatched_weekly_period_end_dow_raises(self): - """Two weekly observations with different period_end_dow: rule 1 violation.""" + def test_mismatched_weekly_week_raises(self): + """Two weekly observations with different WeekCycles: rule 1 violation.""" builder = _coherence_builder( single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), observations=[ - _weekly_hosp_counts(name="hospital", period_end_dow=5), - _weekly_hosp_counts(name="other", period_end_dow=6), + _weekly_hosp_counts(name="hospital", week=MMWR_WEEK), + _weekly_hosp_counts(name="other", week=WeekCycle(start_dow=0)), ], ) - with pytest.raises(ValueError, match="must agree on period_end_dow"): + with pytest.raises(ValueError, match="must share a single WeekCycle"): builder.build() - def test_matching_weekly_period_end_dow_passes(self): - """Two weekly observations agreeing on period_end_dow: rule 1 passes.""" + def test_matching_weekly_week_passes(self): + """Two weekly observations sharing a WeekCycle: rule 1 passes.""" builder = _coherence_builder( single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), observations=[ - _weekly_hosp_counts(name="hospital", period_end_dow=5), - _weekly_hosp_counts(name="other", period_end_dow=5), + _weekly_hosp_counts(name="hospital", week=MMWR_WEEK), + _weekly_hosp_counts(name="other", week=MMWR_WEEK), ], ) model = builder.build() @@ -713,34 +714,32 @@ def test_invalid_temporal_process_step_size_raises( with pytest.raises(ValueError, match="positive integer step_size"): builder.build() - def test_calendar_week_alignment_matches_weekly_period_end_dow_passes(self): + def test_calendar_week_alignment_matches_weekly_week_passes(self): """Sunday-start weeks pair with Saturday-ending weekly observations.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), - observations=[_weekly_hosp_counts(period_end_dow=5)], + observations=[_weekly_hosp_counts(week=MMWR_WEEK)], ) model = builder.build() assert isinstance(model, MultiSignalModel) - def test_calendar_week_alignment_mismatches_weekly_period_end_dow_raises(self): + def test_calendar_week_alignment_mismatches_weekly_week_raises(self): """A weekly Rt anchor must agree with the weekly observation anchor.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=0, + week=WeekCycle(start_dow=0), ), - observations=[_weekly_hosp_counts(period_end_dow=5)], + observations=[_weekly_hosp_counts(week=MMWR_WEEK)], ) - with pytest.raises( - ValueError, match="weekly observations should end on period_end_dow=6" - ): + with pytest.raises(ValueError, match="disagrees with the weekly observation"): builder.build() def test_calendar_week_alignment_with_only_daily_observations_passes(self): @@ -750,20 +749,20 @@ def test_calendar_week_alignment_with_only_daily_observations_passes(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ), observations=[_daily_ed_counts()], ) model = builder.build() assert isinstance(model, MultiSignalModel) - def test_model_index_alignment_ignores_weekly_period_end_dow(self): + def test_model_index_alignment_ignores_weekly_week(self): """Model-index alignment carries no calendar anchor to compare against.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 ), - observations=[_weekly_hosp_counts(period_end_dow=5)], + observations=[_weekly_hosp_counts(week=MMWR_WEEK)], ) model = builder.build() assert isinstance(model, MultiSignalModel) diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index f9c6628f..1ef3326f 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -7,6 +7,7 @@ import pytest from pyrenew.latent import AR1, DifferencedAR1, RandomWalk, StepwiseTemporalProcess +from pyrenew.time import MMWR_WEEK INNER_PROCESS_PARAMS = [ (AR1, {"autoreg": 0.9, "innovation_sd": 0.05}), @@ -304,25 +305,25 @@ def test_calendar_week_requires_weekly_step_size(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=3, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ) - def test_calendar_week_requires_week_start_dow(self): - """calendar_week alignment requires a declared week start.""" - with pytest.raises(ValueError, match="week_start_dow"): + def test_calendar_week_requires_week(self): + """calendar_week alignment requires a WeekCycle.""" + with pytest.raises(ValueError, match="week is required"): StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", ) - def test_model_index_rejects_week_start_dow(self): - """week_start_dow is only meaningful for calendar_week alignment.""" - with pytest.raises(ValueError, match="week_start_dow"): + def test_model_index_rejects_week(self): + """week is only meaningful for calendar_week alignment.""" + with pytest.raises(ValueError, match="week is only used"): StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, - week_start_dow=6, + week=MMWR_WEEK, ) @@ -380,7 +381,7 @@ def test_calendar_week_requires_first_day_dow_at_sample_time(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ) with numpyro.handlers.seed(rng_seed=42): with pytest.raises(ValueError, match="first_day_dow"): @@ -392,7 +393,7 @@ def test_calendar_week_alignment_with_leading_partial_week(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ) # first_day_dow=3 means day 0 is Thursday. With Sunday week starts, # days 0-2 are a leading partial week, then days 3-9 are the first @@ -415,7 +416,7 @@ def test_calendar_week_alignment_without_leading_partial_week(self): AR1(autoreg=0.9, innovation_sd=0.05), step_size=7, alignment="calendar_week", - week_start_dow=6, + week=MMWR_WEEK, ) with numpyro.handlers.seed(rng_seed=42): result = wrapper.sample( diff --git a/test/test_time.py b/test/test_time.py index a25bcd82..b6e86992 100644 --- a/test/test_time.py +++ b/test/test_time.py @@ -762,3 +762,52 @@ def test_create_date_time_spine_date_values(): assert dates[0] == dt.date(2025, 1, 1) assert dates[1] == dt.date(2025, 1, 2) assert dates[2] == dt.date(2025, 1, 3) + + +class TestWeekCycle: + """Tests for the :class:`pyrenew.time.WeekCycle` value object.""" + + def test_construction_and_end_dow(self): + """Construct with each valid ``start_dow`` and verify ``end_dow``.""" + for start in range(7): + cycle = ptime.WeekCycle(start_dow=start) + assert cycle.start_dow == start + assert cycle.end_dow == (start + 6) % 7 + + def test_mmwr_and_iso_constants(self): + """Named constants encode the MMWR and ISO weekly conventions.""" + assert ptime.MMWR_WEEK == ptime.WeekCycle(start_dow=6) + assert ptime.MMWR_WEEK.end_dow == 5 + assert ptime.ISO_WEEK == ptime.WeekCycle(start_dow=0) + assert ptime.ISO_WEEK.end_dow == 6 + + @pytest.mark.parametrize("bad", [-1, 7, 42]) + def test_invalid_start_dow_raises(self, bad): + """Out-of-range ``start_dow`` raises via :func:`validate_dow`.""" + with pytest.raises(ValueError, match="Day-of-week"): + ptime.WeekCycle(start_dow=bad) + + @pytest.mark.parametrize("bad", [None, "6", 6.0]) + def test_non_integer_start_dow_raises(self, bad): + """Non-integer ``start_dow`` raises via :func:`validate_dow`.""" + with pytest.raises(ValueError, match="must be integers"): + ptime.WeekCycle(start_dow=bad) + + def test_frozen_instance_rejects_mutation(self): + """``frozen=True`` blocks attribute assignment after construction.""" + cycle = ptime.WeekCycle(start_dow=3) + with pytest.raises(Exception): + cycle.start_dow = 4 + + def test_equality_and_hashability(self): + """Structural equality and set membership work as expected.""" + a = ptime.WeekCycle(start_dow=6) + b = ptime.WeekCycle(start_dow=6) + c = ptime.WeekCycle(start_dow=0) + assert a == b + assert a != c + assert len({a, b, c}) == 2 + + def test_not_equal_to_plain_tuple(self): + """Type-strict equality: a ``WeekCycle`` is not equal to ``(6,)``.""" + assert ptime.WeekCycle(start_dow=6) != (6,) From bfc23b434c08a16ff181bfa4bfd198b8ceec694a Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 23 Apr 2026 14:30:21 -0400 Subject: [PATCH 12/17] simplify model sample/run method - specify obs data start date --- .../tutorials/building_multisignal_models.qmd | 23 ++-- docs/tutorials/day_of_week_effects.qmd | 4 +- pyrenew/model/multisignal_model.py | 119 +++++++++++++----- .../test_population_infections_he.py | 9 +- .../test_population_infections_he_weekly.py | 15 ++- ...test_population_infections_he_weekly_rt.py | 24 ++-- test/test_pyrenew_builder.py | 75 ++++++++--- 7 files changed, 187 insertions(+), 82 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 53ec13e6..656643f8 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -39,6 +39,7 @@ import plotnine as p9 import pandas as pd import time import warnings +from datetime import date warnings.filterwarnings("ignore") @@ -269,24 +270,24 @@ weekly_baseline_rt_process = StepwiseTemporalProcess( Use `alignment="calendar_week"` when the weekly Rt blocks should align to a calendar week. Pass the same `WeekCycle` to every weekly component that must agree on the calendar: the `StepwiseTemporalProcess` above and any weekly `PopulationCounts` or `SubpopulationCounts` observation. `pyrenew.time` exports `MMWR_WEEK` (Sunday-Saturday epiweeks) and `ISO_WEEK` (Monday-Sunday); use `WeekCycle(start_dow=k)` for any other convention. -At sample or run time, pass the day of week for element 0 of the shared model axis: +At sample or run time, pass the date of the first observation day as `obs_start_date`. +The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it: ```python -first_day_dow = model.compute_first_day_dow(obs_start_dow) - model.run( - rng_key, - n_warmup=500, - n_samples=1000, - first_day_dow=first_day_dow, - hospital={"obs": hospital_obs, "first_day_dow": first_day_dow}, + num_warmup=500, + num_samples=1000, + rng_key=random.PRNGKey(42), + n_days_post_init=n_days_post_init, + population_size=population_size, + obs_start_date=date(2024, 1, 7), + hospital={"obs": hospital_obs}, ed={"obs": ed_obs}, ) ``` -The top-level `first_day_dow` is used by calendar-aligned latent temporal processes. -The observation-level `first_day_dow` is used by observation processes that need calendar alignment, such as weekly aggregation or day-of-week effects. -For `alignment="model_index"`, stepwise blocks start at model index 0 and no top-level `first_day_dow` is needed. +`obs_start_date` is required whenever any component performs calendar-aligned work: a latent temporal process with `alignment="calendar_week"`, a count observation with `aggregation="weekly"`, or any observation with a day-of-week effect. +For `alignment="model_index"` and daily observations with no day-of-week effect, `obs_start_date` can be omitted. ## Observation Processes diff --git a/docs/tutorials/day_of_week_effects.qmd b/docs/tutorials/day_of_week_effects.qmd index ccf4058f..147bfc8c 100644 --- a/docs/tutorials/day_of_week_effects.qmd +++ b/docs/tutorials/day_of_week_effects.qmd @@ -257,8 +257,8 @@ offset_df["offset"] = pd.Categorical( The two curves have the same shape but are phase-shifted: their weekend dips fall on different days. Getting `first_day_dow` right matters — a misaligned offset would attribute Monday's high to Sunday or vice versa. -When using `MultiSignalModel`, the shared time axis starts `n_init` days before the first observation. -The convenience method `model.compute_first_day_dow(obs_start_dow)` converts the known day of the week of the first observation to the correct offset for element 0 of the time axis. +When using `MultiSignalModel`, pass `obs_start_date` (the date of the first observation day) to `model.sample()` or `model.run()`. +The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it. ## Sampled observations diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 616471c7..39afed99 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -6,7 +6,10 @@ from __future__ import annotations +import datetime as dt + import jax.numpy as jnp +import numpy as np import numpyro import numpyro.handlers from jax.typing import ArrayLike @@ -14,6 +17,7 @@ from pyrenew.latent.base import BaseLatentInfectionProcess from pyrenew.metaclass import Model from pyrenew.observation.base import BaseObservationProcess +from pyrenew.time import convert_date class MultiSignalModel(Model): @@ -130,29 +134,71 @@ def pad_observations( padding = jnp.full(pad_shape, jnp.nan) return jnp.concatenate([padding, obs], axis=axis) - def compute_first_day_dow(self, obs_start_dow: int) -> int: + def _resolve_first_day_dow( + self, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None, + ) -> int | None: """ - Compute the day of the week for the start of the shared time axis. + Derive the axis-origin day-of-week from ``obs_start_date``. - The shared time axis begins ``n_init`` days before the first - observation. This method converts the known day of the week of - the first observation into the day of the week of the shared - time axis start (element 0), accounting for the initialization - period offset. + The shared daily time axis begins ``n_init`` days before the + first observation. This method converts a user-supplied + first-observation date into the day-of-week of element 0 of + that axis, accounting for the initialization-period offset. Parameters ---------- - obs_start_dow : int - Day of the week of the first observation day - (0=Monday, 6=Sunday, ISO convention). + obs_start_date + Date of the first observation day, or ``None``. Returns ------- - int - Day of the week for element 0 of the shared time axis. + int or None + Day-of-week index in ``{0, ..., 6}`` (0=Monday, ISO + convention) of element 0 of the shared axis. ``None`` when + ``obs_start_date`` is ``None``. """ + if obs_start_date is None: + return None n_init = self.latent.n_initialization_points - return (obs_start_dow - n_init) % 7 + return (convert_date(obs_start_date).weekday() - n_init) % 7 + + def _require_obs_start_date_if_weekly( + self, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None, + ) -> None: + """ + Validate that ``obs_start_date`` is supplied whenever any observation needs it. + + Observations with ``aggregation="weekly"`` or a + ``day_of_week_rv`` require a calendar anchor. Rather than + surface a downstream error from the observation itself, raise + at the model entry with the name of the offending observation. + + Parameters + ---------- + obs_start_date + Top-level calendar anchor passed by the caller. + + Raises + ------ + ValueError + If ``obs_start_date`` is ``None`` and any observation + requires a calendar anchor. + """ + if obs_start_date is not None: + return + for name, obs in self.observations.items(): + if getattr(obs, "aggregation", "daily") == "weekly": + raise ValueError( + f"obs_start_date is required when any observation uses " + f"aggregation='weekly'; observation '{name}' does." + ) + if getattr(obs, "day_of_week_rv", None) is not None: + raise ValueError( + f"obs_start_date is required when any observation uses " + f"a day-of-week effect; observation '{name}' does." + ) def shift_times(self, times: jnp.ndarray) -> jnp.ndarray: """ @@ -179,7 +225,9 @@ def shift_times(self, times: jnp.ndarray) -> jnp.ndarray: def validate_data( self, n_days_post_init: int, + *, subpop_fractions: ArrayLike | None = None, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None = None, **observation_data: dict[str, object], ) -> None: """ @@ -197,9 +245,16 @@ def validate_data( Parameters ---------- n_days_post_init - Number of days to simulate after initialization period + Number of days to simulate after initialization period. subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). + obs_start_date + Date of the first observation day. Required when any + observation uses ``aggregation="weekly"`` or a day-of-week + effect, or when a latent temporal process uses + calendar-week alignment. Converted once to the axis-origin + ``first_day_dow`` and forwarded to the latent process and + every observation. **observation_data Data for each observation process, keyed by observation name. Each value should be a dict of kwargs for that observation's sample(). @@ -207,16 +262,19 @@ def validate_data( Raises ------ ValueError - If times indices are out of bounds or negative - If dense obs length doesn't match n_total - If data shapes are inconsistent + If times indices are out of bounds or negative, if dense obs + length doesn't match n_total, if data shapes are inconsistent, + or if ``obs_start_date`` is missing when an observation requires it. """ + self._require_obs_start_date_if_weekly(obs_start_date) + pop = self.latent._parse_and_validate_fractions( subpop_fractions=subpop_fractions, ) n_init = self.latent.n_initialization_points n_total = n_init + n_days_post_init + first_day_dow = self._resolve_first_day_dow(obs_start_date) for name, obs_data in observation_data.items(): if name not in self.observations: @@ -226,18 +284,10 @@ def validate_data( ) obs = self.observations[name] - if getattr(obs, "aggregation_period", 1) > 1 and ( - "first_day_dow" not in obs_data - ): - raise ValueError( - f"Observation '{name}' has aggregation_period=" - f"{obs.aggregation_period} and requires 'first_day_dow' " - f"in its observation data." - ) - obs.validate_data( n_total=n_total, n_subpops=pop.n_subpops, + first_day_dow=first_day_dow, **obs_data, ) @@ -247,7 +297,7 @@ def sample( population_size: float, *, subpop_fractions: ArrayLike | None = None, - first_day_dow: int | None = None, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None = None, **observation_data: dict[str, object], ) -> None: """ @@ -264,9 +314,14 @@ def sample( (from latent process) to infection counts (for observation processes). subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). - first_day_dow - Forwarded to the latent process. See - [pyrenew.latent.TemporalProcess][]. + obs_start_date + Date of the first observation day. Converted once to the + axis-origin ``first_day_dow`` (day-of-week of element 0 of + the padded axis, after subtracting ``n_init``) and forwarded + to the latent process and every observation. Required when + any observation uses ``aggregation="weekly"`` or a + day-of-week effect, or when a latent temporal process uses + calendar-week alignment. **observation_data Data for each observation process, keyed by observation name (the ``name`` attribute of each observation process). @@ -280,6 +335,9 @@ def sample( observation sites. Use ``numpyro.infer.Predictive`` for forward sampling. """ + self._require_obs_start_date_if_weekly(obs_start_date) + first_day_dow = self._resolve_first_day_dow(obs_start_date) + # Generate latent infections (proportions) latent_sample = self.latent.sample( n_days_post_init=n_days_post_init, @@ -318,6 +376,7 @@ def sample( # Sample from observation process obs_process.sample( infections=latent_infections, + first_day_dow=first_day_dow, **obs_data, ) diff --git a/test/integration/test_population_infections_he.py b/test/integration/test_population_infections_he.py index 03f5c4a3..16976028 100644 --- a/test/integration/test_population_infections_he.py +++ b/test/integration/test_population_infections_he.py @@ -8,6 +8,8 @@ from __future__ import annotations +from datetime import date + import arviz as az import jax import jax.numpy as jnp @@ -130,7 +132,6 @@ def fitted_model( ) population_size = float(daily_hosp["pop"][0]) - first_dow = 0 # 2023-11-06 is a Monday he_model.run( num_warmup=NUM_WARMUP, @@ -139,11 +140,9 @@ def fitted_model( mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, n_days_post_init=N_DAYS_FIT, population_size=population_size, + obs_start_date=date(2023, 11, 6), hospital={"obs": hosp_obs}, - ed={ - "obs": ed_obs, - "first_day_dow": he_model.compute_first_day_dow(first_dow), - }, + ed={"obs": ed_obs}, ) samples = he_model.mcmc.get_samples() diff --git a/test/integration/test_population_infections_he_weekly.py b/test/integration/test_population_infections_he_weekly.py index 01888bd7..048b0e6d 100644 --- a/test/integration/test_population_infections_he_weekly.py +++ b/test/integration/test_population_infections_he_weekly.py @@ -9,6 +9,8 @@ from __future__ import annotations +from datetime import date + import arviz as az import jax import jax.numpy as jnp @@ -27,8 +29,8 @@ NUM_WARMUP = 500 NUM_SAMPLES = 500 NUM_CHAINS = 4 -# Day 0 of the synthetic data is 2023-11-05, a Sunday (ISO dow = 6). -OBS_START_DOW = 6 +# First observation day of the synthetic data. 2023-11-05 is a Sunday (ISO dow = 6). +OBS_START_DATE = date(2023, 11, 5) def _build_hospital_obs_on_period_grid( @@ -144,7 +146,7 @@ def test_weekly_obs_alignment( weekly_hosp : pl.DataFrame Weekly hospital admissions. """ - first_day_dow = he_weekly_model.compute_first_day_dow(OBS_START_DOW) + first_day_dow = he_weekly_model._resolve_first_day_dow(OBS_START_DATE) weekly_values = jnp.array( weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 ) @@ -191,7 +193,7 @@ def fitted_model( MultiSignalModel Model with MCMC results attached. """ - first_day_dow = he_weekly_model.compute_first_day_dow(OBS_START_DOW) + first_day_dow = he_weekly_model._resolve_first_day_dow(OBS_START_DATE) weekly_values = jnp.array( weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 @@ -213,8 +215,9 @@ def fitted_model( mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, n_days_post_init=N_DAYS_FIT, population_size=population_size, - hospital={"obs": hosp_obs, "first_day_dow": first_day_dow}, - ed={"obs": ed_obs, "first_day_dow": first_day_dow}, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, ) samples = he_weekly_model.mcmc.get_samples() diff --git a/test/integration/test_population_infections_he_weekly_rt.py b/test/integration/test_population_infections_he_weekly_rt.py index 0cb24224..f1e438ac 100644 --- a/test/integration/test_population_infections_he_weekly_rt.py +++ b/test/integration/test_population_infections_he_weekly_rt.py @@ -10,6 +10,8 @@ from __future__ import annotations +from datetime import date + import arviz as az import jax import jax.numpy as jnp @@ -28,8 +30,8 @@ NUM_WARMUP = 500 NUM_SAMPLES = 500 NUM_CHAINS = 4 -# Day 0 of the synthetic data is 2023-11-05, a Sunday (ISO dow = 6). -OBS_START_DOW = 6 +# First observation day of the synthetic data. 2023-11-05 is a Sunday (ISO dow = 6). +OBS_START_DATE = date(2023, 11, 5) WEEK_START_DOW = 6 @@ -112,7 +114,7 @@ def test_coarse_rt_recorded( daily_ed : pl.DataFrame Daily ED visits. """ - first_day_dow = he_weekly_rt_model.compute_first_day_dow(OBS_START_DOW) + first_day_dow = he_weekly_rt_model._resolve_first_day_dow(OBS_START_DATE) weekly_values = jnp.array( weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 ) @@ -129,9 +131,9 @@ def test_coarse_rt_recorded( he_weekly_rt_model.sample( n_days_post_init=N_DAYS_FIT, population_size=population_size, - first_day_dow=first_day_dow, - hospital={"obs": hosp_obs, "first_day_dow": first_day_dow}, - ed={"obs": ed_obs, "first_day_dow": first_day_dow}, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, ) n_total = he_weekly_rt_model.latent.n_initialization_points + N_DAYS_FIT @@ -180,7 +182,7 @@ def fitted_model( MultiSignalModel Model with MCMC results attached. """ - first_day_dow = he_weekly_rt_model.compute_first_day_dow(OBS_START_DOW) + first_day_dow = he_weekly_rt_model._resolve_first_day_dow(OBS_START_DATE) weekly_values = jnp.array( weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 @@ -202,9 +204,9 @@ def fitted_model( mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, n_days_post_init=N_DAYS_FIT, population_size=population_size, - first_day_dow=first_day_dow, - hospital={"obs": hosp_obs, "first_day_dow": first_day_dow}, - ed={"obs": ed_obs, "first_day_dow": first_day_dow}, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, ) samples = he_weekly_rt_model.mcmc.get_samples() @@ -302,7 +304,7 @@ def test_coarse_rt_posterior_shape( posterior_dt : xarray.DataTree ArviZ DataTree with posterior group. """ - first_day_dow = fitted_model.compute_first_day_dow(OBS_START_DOW) + first_day_dow = fitted_model._resolve_first_day_dow(OBS_START_DATE) n_coarse = _expected_n_coarse(fitted_model, first_day_dow) coarse = posterior_dt.posterior["log_rt_single_coarse"] diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index a029c49a..d0761431 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -2,6 +2,8 @@ Tests for PyrenewBuilder and MultiSignalModel. """ +from datetime import date, timedelta + import jax.numpy as jnp import numpyro import pytest @@ -27,6 +29,31 @@ SUBPOP_FRACTIONS = jnp.array([0.3, 0.25, 0.45]) +def _obs_date_for_dow(target_first_day_dow: int, n_init: int) -> date: + """ + Return an ``obs_start_date`` whose axis-origin day-of-week matches. + + Given a desired ``first_day_dow`` (day-of-week of element 0 of the + padded axis) and the model's ``n_init``, pick a concrete date for + the first observation day such that + ``(obs_start_date.weekday() - n_init) % 7 == target_first_day_dow``. + + Parameters + ---------- + target_first_day_dow + Desired axis-origin day-of-week in ``{0, ..., 6}``. + n_init + Initialization-period length in days. + + Returns + ------- + datetime.date + A date with the required day-of-week, drawn from January 2024. + """ + obs_dow = (target_first_day_dow + n_init) % 7 + return date(2024, 1, 1) + timedelta(days=obs_dow) + + @pytest.fixture def simple_builder(): """ @@ -299,7 +326,7 @@ def test_first_day_dow_reaches_calendar_aligned_latent_process(self): model.sample( n_days_post_init=10, population_size=1_000_000, - first_day_dow=3, + obs_start_date=_obs_date_for_dow(target_first_day_dow=3, n_init=3), ed={"obs": None}, ) @@ -311,8 +338,8 @@ def test_first_day_dow_reaches_calendar_aligned_latent_process(self): assert jnp.allclose(log_rt[:3], log_rt[0]) assert jnp.allclose(log_rt[3:10], log_rt[3]) - def test_missing_first_day_dow_for_calendar_aligned_latent_process_raises(self): - """Calendar-aligned latent temporal processes require model-axis DOW.""" + def test_missing_obs_start_date_for_calendar_aligned_latent_process_raises(self): + """Calendar-aligned latent temporal processes require a calendar anchor.""" latent = PopulationInfections( name="PopulationInfections", gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), @@ -357,7 +384,10 @@ def test_builder_mixed_cadence_weekly_rt_samples(self): model = builder.build() n_days_post_init = 28 n_total = model.latent.n_initialization_points + n_days_post_init - first_day_dow = 6 + obs_start_date = _obs_date_for_dow( + target_first_day_dow=6, + n_init=model.latent.n_initialization_points, + ) hospital_obs = jnp.array([jnp.nan, 5.0, 7.0, 6.0], dtype=float) ed_obs = jnp.concatenate( [ @@ -371,8 +401,8 @@ def test_builder_mixed_cadence_weekly_rt_samples(self): model.sample( n_days_post_init=n_days_post_init, population_size=1_000_000, - first_day_dow=first_day_dow, - hospital={"obs": hospital_obs, "first_day_dow": first_day_dow}, + obs_start_date=obs_start_date, + hospital={"obs": hospital_obs}, ed={"obs": ed_obs}, ) @@ -545,10 +575,16 @@ def test_pad_observations_prepends_nans(self, simple_builder): (6, (6 - 3) % 7), ], ) - def test_compute_first_day_dow(self, simple_builder, obs_start_dow, expected): - """Test that compute_first_day_dow offsets by n_initialization_points.""" + def test_resolve_first_day_dow(self, simple_builder, obs_start_dow, expected): + """_resolve_first_day_dow offsets by n_initialization_points.""" model = simple_builder.build() - assert model.compute_first_day_dow(obs_start_dow) == expected + obs_date = date(2024, 1, 1) + timedelta(days=obs_start_dow) + assert model._resolve_first_day_dow(obs_date) == expected + + def test_resolve_first_day_dow_none_passthrough(self, simple_builder): + """_resolve_first_day_dow returns None when obs_start_date is None.""" + model = simple_builder.build() + assert model._resolve_first_day_dow(None) is None def test_shift_times_adds_offset(self, simple_builder): """Test that shift_times shifts by n_initialization_points.""" @@ -769,10 +805,10 @@ def test_model_index_alignment_ignores_weekly_week(self): class TestMultiSignalValidateDataAnchor: - """MultiSignalModel.validate_data sample-time anchor check for first_day_dow.""" + """MultiSignalModel.validate_data sample-time anchor check for obs_start_date.""" - def test_missing_first_day_dow_for_weekly_obs_raises(self): - """An observation with aggregation_period>1 must have first_day_dow supplied.""" + def test_missing_obs_start_date_for_weekly_obs_raises(self): + """An observation with aggregation='weekly' must have obs_start_date supplied.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 @@ -780,14 +816,14 @@ def test_missing_first_day_dow_for_weekly_obs_raises(self): observations=[_weekly_hosp_counts()], ) model = builder.build() - with pytest.raises(ValueError, match="requires 'first_day_dow'"): + with pytest.raises(ValueError, match="obs_start_date is required"): model.validate_data( n_days_post_init=28, hospital={"obs": jnp.ones(4) * 5.0}, ) - def test_first_day_dow_supplied_for_weekly_obs_passes(self): - """Supplying first_day_dow satisfies the anchor check.""" + def test_obs_start_date_supplied_for_weekly_obs_passes(self): + """Supplying obs_start_date satisfies the anchor check.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 @@ -795,13 +831,18 @@ def test_first_day_dow_supplied_for_weekly_obs_passes(self): observations=[_weekly_hosp_counts()], ) model = builder.build() + obs_start_date = _obs_date_for_dow( + target_first_day_dow=6, + n_init=model.latent.n_initialization_points, + ) model.validate_data( n_days_post_init=28, - hospital={"obs": jnp.ones(4) * 5.0, "first_day_dow": 6}, + obs_start_date=obs_start_date, + hospital={"obs": jnp.ones(4) * 5.0}, ) def test_anchor_check_skipped_for_daily_obs(self): - """Daily observations do not require first_day_dow at validate_data time.""" + """Daily observations do not require obs_start_date at validate_data time.""" builder = _coherence_builder( single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), observations=[_daily_ed_counts()], From 529c494f9228ca3e1157100e1e7675a85a72c50b Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 23 Apr 2026 16:44:51 -0400 Subject: [PATCH 13/17] fix tutorial --- .../tutorials/building_multisignal_models.qmd | 26 +++++++++++++--- pyrenew/latent/temporal_processes.py | 2 ++ pyrenew/observation/base.py | 28 +++++++++++++---- .../observation/measurement_observations.py | 5 ++++ test/test_observation_counts.py | 4 +-- test/test_observation_validation.py | 30 +++++++++++++++++-- 6 files changed, 80 insertions(+), 15 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 656643f8..047fd0dd 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -320,7 +320,7 @@ for the first 10 months of 2023 (as reported to the CDC). # | label: load-hospital-data # Load daily hospital admissions for California ca_hosp_data = datasets.load_hospital_data_for_state("CA", "2023-11-06.csv") - +obs_start_date = ca_hosp_data["dates"][0] hosp_admits = ca_hosp_data["daily_admits"] population_size = ca_hosp_data["population"] n_hosp_days = ca_hosp_data["n_days"] @@ -755,7 +755,12 @@ We define a function to prepare observation data using the model's helper method ```{python} # | label: prepare-observation-data def prepare_observation_data( - model, n_days_fit, hosp_admits, ww_data, ww_monitored_subpops + model, + n_days_fit, + obs_start_date, + hosp_admits, + ww_data, + ww_monitored_subpops, ): """ Prepare observation data for fitting. @@ -769,6 +774,8 @@ def prepare_observation_data( The model (provides padding/shifting helpers) n_days_fit : int Number of days to include in fit + obs_start_date : date + Date of first observation hosp_admits : array Hospital admissions time series ww_data : dict @@ -801,6 +808,7 @@ def prepare_observation_data( ) return { + "obs_start_date": obs_start_date, "hospital": { "obs": hosp_obs, }, @@ -826,7 +834,12 @@ jax.clear_caches() n_days_90 = 90 obs_data_90 = prepare_observation_data( - model, n_days_90, hosp_admits, ca_ww_data, ww_monitored_subpops + model, + n_days_90, + obs_start_date, + hosp_admits, + ca_ww_data, + ww_monitored_subpops, ) print(f"Fitting model with {n_days_90} days of data...") @@ -991,7 +1004,12 @@ jax.clear_caches() n_days_180 = 180 obs_data_180 = prepare_observation_data( - model, n_days_180, hosp_admits, ca_ww_data, ww_monitored_subpops + model, + n_days_180, + obs_start_date, + hosp_admits, + ca_ww_data, + ww_monitored_subpops, ) print(f"Fitting model with {n_days_180} days of data...") diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 8fbe46fd..0bd8054c 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -604,6 +604,8 @@ def sample( within each block of ``step_size`` consecutive rows. """ n_steps = self._resolve_n_coarse(n_timepoints, first_day_dow) + # first_day_dow intentionally not forwarded: inner operates on the + # coarse axis; the outer's axis-origin day-of-week does not apply. coarse = self.inner.sample( n_timepoints=n_steps, initial_value=initial_value, diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index bdddf601..1077cbb4 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -588,11 +588,13 @@ def _validate_period_end_times( Validate a period-end-time index array. Checks that all values are non-negative, within - ``[0, n_total)``, and lie on aggregation-period boundaries, - i.e., ``(t - offset) % aggregation_period == - aggregation_period - 1`` for every entry ``t``. When - ``aggregation_period == 1`` the alignment condition holds - trivially and only the bounds check runs. + ``[0, n_total)``, are at least ``offset + aggregation_period - 1`` + (so the earliest period end points at the last day of the first + complete period, not inside the leading partial period), and lie + on aggregation-period boundaries, i.e., ``(t - offset) % + aggregation_period == aggregation_period - 1`` for every entry + ``t``. When ``aggregation_period == 1`` the alignment condition + holds trivially and only the bounds check runs. Parameters ---------- @@ -611,12 +613,26 @@ def _validate_period_end_times( ------ ValueError If ``period_end_times`` contains negative values, values - ``>= n_total``, or entries that fail the alignment check. + ``>= n_total``, values below the first complete period's + final day, or entries that fail the alignment check. """ self._validate_index_array(period_end_times, n_total, "period_end_times") if aggregation_period == 1: return period_end_times = jnp.asarray(period_end_times) + + # Lower bound: the first complete period's final day. An entry below + # this would yield a negative period index under fancy-indexing + # (which JAX silently wraps to the last element of the aggregated + # array). + min_valid = offset + aggregation_period - 1 + if jnp.any(period_end_times < min_valid): + raise ValueError( + f"Observation '{self.name}': period_end_times must be >= " + f"{min_valid} (= offset + aggregation_period - 1); entries " + f"below this do not correspond to a complete aggregation period." + ) + misaligned = (period_end_times - offset) % aggregation_period != ( aggregation_period - 1 ) diff --git a/pyrenew/observation/measurement_observations.py b/pyrenew/observation/measurement_observations.py index 137e8ce0..03fe13e2 100644 --- a/pyrenew/observation/measurement_observations.py +++ b/pyrenew/observation/measurement_observations.py @@ -160,6 +160,7 @@ def sample( sensor_indices: ArrayLike, n_sensors: int, obs: ArrayLike | None = None, + **kwargs: object, ) -> ObservationSample: """ Sample measurements from observed sensors. @@ -189,6 +190,10 @@ def sample( Total number of measurement sensors. obs Observed measurements (n_obs,), or None for prior sampling. + **kwargs + Additional keyword arguments forwarded by the model + dispatch (e.g., ``first_day_dow``); ignored here because + measurement observations index the shared axis directly. Returns ------- diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 62f6e6f0..b00f842e 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -991,7 +991,7 @@ def test_weekly_irregular_aligned_times_pass(self, weekly_irregular_counts): def test_weekly_irregular_misaligned_times_raise(self, weekly_irregular_counts): """Weekly-irregular with non-Saturday period_end_times raises.""" - period_end_times = jnp.array([5, 13, 20]) + period_end_times = jnp.array([6, 12, 20]) with pytest.raises(ValueError, match="period_end_times must lie on"): weekly_irregular_counts.validate_data( n_total=28, @@ -1277,7 +1277,7 @@ def test_weekly_irregular_valid_passes( def test_weekly_irregular_misaligned_raises(self, weekly_irregular_subpop_counts): """Weekly-irregular with non-Saturday period_end_times raises.""" - period_end_times = jnp.array([5, 13, 20]) + period_end_times = jnp.array([6, 12, 20]) subpop_indices = jnp.array([0, 1, 2]) with pytest.raises(ValueError, match="period_end_times must lie on"): weekly_irregular_subpop_counts.validate_data( diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py index d08d7831..d30a54db 100644 --- a/test/test_observation_validation.py +++ b/test/test_observation_validation.py @@ -780,8 +780,8 @@ def test_p1_any_in_bounds_passes(self, counts_proc): ) def test_misaligned_time_raises(self, counts_proc): - """P=7 with a non-boundary time should raise.""" - times = jnp.array([5]) + """P=7 with a non-boundary time (past the first complete period) should raise.""" + times = jnp.array([7]) with pytest.raises(ValueError, match="period_end_times must lie on"): counts_proc._validate_period_end_times( times, n_total=21, offset=0, aggregation_period=7 @@ -813,8 +813,32 @@ def test_time_at_n_total_raises(self, counts_proc): def test_error_reports_offset_and_period(self, counts_proc): """Alignment error message should include offset and aggregation_period.""" - times = jnp.array([5]) + times = jnp.array([7]) with pytest.raises(ValueError, match=r"offset=0.*aggregation_period=7"): counts_proc._validate_period_end_times( times, n_total=21, offset=0, aggregation_period=7 ) + + def test_time_before_first_complete_period_raises(self, counts_proc): + """ + An entry before the first complete period's final day must raise. + + With ``offset=6, aggregation_period=7``, ``t=5`` satisfies the + modulo boundary check ``(5 - 6) % 7 == 6`` (Python negative + modulo), but maps to ``period_idx = -1`` under fancy indexing, + which JAX silently wraps. The lower-bound check must reject it. + """ + times = jnp.array([5]) + with pytest.raises( + ValueError, match="do not correspond to a complete aggregation period" + ): + counts_proc._validate_period_end_times( + times, n_total=20, offset=6, aggregation_period=7 + ) + + def test_time_at_first_complete_period_end_passes(self, counts_proc): + """The first complete period's final day is the lower-bound edge case.""" + times = jnp.array([12]) + counts_proc._validate_period_end_times( + times, n_total=20, offset=6, aggregation_period=7 + ) From 7a67e8f7cc79296d9bbb9434ce73ca9bc9c3c8ed Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Thu, 23 Apr 2026 17:18:49 -0400 Subject: [PATCH 14/17] changes per copilot review --- pyrenew/observation/count_observations.py | 14 ++++++++++++-- test/test_observation_validation.py | 7 ++++--- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 9228ba66..24ecd6f2 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -633,7 +633,12 @@ def validate_data( return if period_end_times is None: - return + if obs is None: + return + raise ValueError( + f"Observation '{self.name}': period_end_times is required " + f"when reporting_schedule='irregular' and obs is provided." + ) offset = self._compute_period_offset(first_day_dow, self.week) self._validate_period_end_times( period_end_times, n_total, offset, self.aggregation_period @@ -850,7 +855,12 @@ def validate_data( return if period_end_times is None: - return + if obs is None: + return + raise ValueError( + f"Observation '{self.name}': period_end_times is required " + f"when reporting_schedule='irregular' and obs is provided." + ) offset = self._compute_period_offset(first_day_dow, self.week) self._validate_period_end_times( period_end_times, n_total, offset, self.aggregation_period diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py index d30a54db..a52b6724 100644 --- a/test/test_observation_validation.py +++ b/test/test_observation_validation.py @@ -509,10 +509,11 @@ def test_mismatched_obs_times_raises(self, subpop_proc): obs=obs, ) - def test_obs_without_times_skips_shape_check(self, subpop_proc): - """validate_data with obs but no period_end_times should not check shapes.""" + def test_obs_without_times_raises(self, subpop_proc): + """validate_data with obs but no period_end_times should raise ValueError.""" obs = jnp.array([1.0, 2.0]) - subpop_proc.validate_data(n_total=30, n_subpops=3, obs=obs) + with pytest.raises(ValueError, match="period_end_times is required"): + subpop_proc.validate_data(n_total=30, n_subpops=3, obs=obs) def test_times_without_obs_skips_shape_check(self, subpop_proc): """validate_data with period_end_times but no obs should validate indices only.""" From 55272613f9a9992a71c0cddd19443f332fdf61b4 Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 24 Apr 2026 09:30:54 -0400 Subject: [PATCH 15/17] fix logic on obs_start_date for calendar alignment --- .../tutorials/building_multisignal_models.qmd | 17 +++++------- pyrenew/latent/base.py | 27 +++++++++++++++++++ pyrenew/model/multisignal_model.py | 24 ++++++++++------- test/test_pyrenew_builder.py | 27 +++++++++++++++++-- 4 files changed, 74 insertions(+), 21 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 047fd0dd..19329f3b 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -281,8 +281,7 @@ model.run( n_days_post_init=n_days_post_init, population_size=population_size, obs_start_date=date(2024, 1, 7), - hospital={"obs": hospital_obs}, - ed={"obs": ed_obs}, + ) ``` @@ -710,18 +709,14 @@ model.run( n_days_post_init=n_days, population_size=population_size, subpop_fractions=subpop_fractions, - hospital={"obs": model.pad_observations(hosp_counts)}, - wastewater={ - "obs": ww_conc, - "times": model.shift_times(ww_times), - "subpop_indices": ww_subpop_indices, - "sensor_indices": ww_sensor_indices, - "n_sensors": n_ww_sensors, - }, + **obs_data, ) samples = model.mcmc.get_samples() ``` +where `**obs_data` is a data dictionary which supplies the observation start date as `obs_start_date` and a data dictionary for each set of signal data, where the name of the data dictionary corresponds to the signal name registered on the builder. + + ## Running the Model First we declare the population structure. We have 6 subpopulations, where 5 have wastewater monitoring and 1 does not. The subpopulations with wastewater monitoring need not be contiguous indices—they could be any subset of {0, 1, ..., n_subpops-1}. @@ -751,6 +746,8 @@ print(f"Total population: {float(jnp.sum(subpop_fractions)):.0%}") ``` We define a function to prepare observation data using the model's helper methods to align with the shared time axis. +The returned dictionary is structured to match the keyword arguments of `model.run()`: `obs_start_date` at the top level, plus one sub-dictionary per observation process, keyed by the name registered on the builder (`hospital`, `wastewater`). +At call time the returned dict is unpacked with `**` (for example, `**obs_data_90` in the 90-day fit below), forwarding each entry as a keyword argument to `model.run()`. ```{python} # | label: prepare-observation-data diff --git a/pyrenew/latent/base.py b/pyrenew/latent/base.py index c84d393a..d09fd61c 100644 --- a/pyrenew/latent/base.py +++ b/pyrenew/latent/base.py @@ -338,6 +338,33 @@ def get_required_lookback(self) -> int: """ return len(self.gen_int_rv()) + def requires_calendar_anchor(self) -> bool: + """ + Report whether this latent process needs a calendar anchor at sample time. + + The default implementation inspects this instance's attributes + for :class:`pyrenew.latent.temporal_processes.TemporalProcess` + instances and returns ``True`` if any of them has + ``alignment="calendar_week"``. Subclasses may override to add + additional calendar-aligned components. + + Returns + ------- + bool + ``True`` if the caller of :meth:`sample` must supply a + ``first_day_dow`` (derived from ``obs_start_date`` at the + model entry point); ``False`` otherwise. + """ + from pyrenew.latent.temporal_processes import TemporalProcess + + for value in vars(self).values(): + if ( + isinstance(value, TemporalProcess) + and getattr(value, "alignment", None) == "calendar_week" + ): + return True + return False + @abstractmethod def sample( self, diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 39afed99..7e9a099e 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -163,17 +163,18 @@ def _resolve_first_day_dow( n_init = self.latent.n_initialization_points return (convert_date(obs_start_date).weekday() - n_init) % 7 - def _require_obs_start_date_if_weekly( + def _check_obs_start_date( self, obs_start_date: dt.date | dt.datetime | np.datetime64 | None, ) -> None: """ - Validate that ``obs_start_date`` is supplied whenever any observation needs it. + Check that ``obs_start_date`` is supplied whenever any component needs it. Observations with ``aggregation="weekly"`` or a - ``day_of_week_rv`` require a calendar anchor. Rather than - surface a downstream error from the observation itself, raise - at the model entry with the name of the offending observation. + ``day_of_week_rv`` require a calendar anchor, as do latent + processes with a calendar-week-aligned temporal process. + Rather than surface a downstream error, raise at the model + entry naming the offending component. Parameters ---------- @@ -183,8 +184,8 @@ def _require_obs_start_date_if_weekly( Raises ------ ValueError - If ``obs_start_date`` is ``None`` and any observation - requires a calendar anchor. + If ``obs_start_date`` is ``None`` and any observation or + the latent process requires a calendar anchor. """ if obs_start_date is not None: return @@ -199,6 +200,11 @@ def _require_obs_start_date_if_weekly( f"obs_start_date is required when any observation uses " f"a day-of-week effect; observation '{name}' does." ) + if self.latent.requires_calendar_anchor(): + raise ValueError( + "obs_start_date is required when the latent process uses a " + "temporal process with alignment='calendar_week'." + ) def shift_times(self, times: jnp.ndarray) -> jnp.ndarray: """ @@ -266,7 +272,7 @@ def validate_data( length doesn't match n_total, if data shapes are inconsistent, or if ``obs_start_date`` is missing when an observation requires it. """ - self._require_obs_start_date_if_weekly(obs_start_date) + self._check_obs_start_date(obs_start_date) pop = self.latent._parse_and_validate_fractions( subpop_fractions=subpop_fractions, @@ -335,7 +341,7 @@ def sample( observation sites. Use ``numpyro.infer.Predictive`` for forward sampling. """ - self._require_obs_start_date_if_weekly(obs_start_date) + self._check_obs_start_date(obs_start_date) first_day_dow = self._resolve_first_day_dow(obs_start_date) # Generate latent infections (proportions) diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index d0761431..c89c95ef 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -339,7 +339,7 @@ def test_first_day_dow_reaches_calendar_aligned_latent_process(self): assert jnp.allclose(log_rt[3:10], log_rt[3]) def test_missing_obs_start_date_for_calendar_aligned_latent_process_raises(self): - """Calendar-aligned latent temporal processes require a calendar anchor.""" + """Calendar-aligned latent temporal processes trigger the model-entry anchor check.""" latent = PopulationInfections( name="PopulationInfections", gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), @@ -356,7 +356,9 @@ def test_missing_obs_start_date_for_calendar_aligned_latent_process_raises(self) model = MultiSignalModel(latent, {"ed": _daily_ed_counts()}) with numpyro.handlers.seed(rng_seed=42): - with pytest.raises(ValueError, match="first_day_dow"): + with pytest.raises( + ValueError, match="obs_start_date is required.*calendar_week" + ): model.sample( n_days_post_init=10, population_size=1_000_000, @@ -854,6 +856,27 @@ def test_anchor_check_skipped_for_daily_obs(self): ed={"obs": jnp.ones(n_total) * 5.0}, ) + def test_missing_obs_start_date_for_calendar_aligned_latent_raises(self): + """A calendar-week-aligned latent temporal process requires obs_start_date.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + observations=[_daily_ed_counts()], + ) + model = builder.build() + n_total = model.latent.n_initialization_points + 30 + with pytest.raises( + ValueError, match="obs_start_date is required.*calendar_week" + ): + model.validate_data( + n_days_post_init=30, + ed={"obs": jnp.ones(n_total) * 5.0}, + ) + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 72819b0b5450985436b0f0e643b0408fe375df6f Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 24 Apr 2026 12:47:22 -0400 Subject: [PATCH 16/17] changes per docs review --- .../tutorials/building_multisignal_models.qmd | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 19329f3b..c1436d68 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -247,7 +247,7 @@ This separates three model choices: - **Parameter cadence**: how often the $\mathcal{R}(t)$ temporal process samples a new latent value - **Model time axis**: the daily axis used by the renewal equation and delay convolutions -- **Observation cadence**: the scale of the likelihood for each signal, such as daily ED visits or weekly hospital admissions +- **Observation cadence**: the temporal granularity for each signal, such as daily ED visits or weekly hospital admissions The AR(1) process above samples one value per model day: @@ -255,7 +255,7 @@ The AR(1) process above samples one value per model day: baseline_rt_process = AR1(autoreg=0.9, innovation_sd=0.05) ``` -To parameterize $\mathcal{R}(t)$ weekly while still running the renewal equation daily, wrap the temporal process in `StepwiseTemporalProcess`. +To use a weekly $\mathcal{R}(t)$ while still running the renewal equation daily, wrap the temporal process in `StepwiseTemporalProcess`. The wrapper samples a coarse trajectory and broadcasts it to the daily model axis before the latent infection process uses it. ```python @@ -268,25 +268,14 @@ weekly_baseline_rt_process = StepwiseTemporalProcess( ``` Use `alignment="calendar_week"` when the weekly Rt blocks should align to a calendar week. +PyRenew provides class `WeekCycle` which identifies the start day of a calendar week by ISO convention, e.g. `0 == Monday`, `6 == Sunday` +and two pre-defined instances: `MMWR_WEEK: WeekCycle = WeekCycle(start_dow=6)` and `ISO_WEEK: WeekCycle = WeekCycle(start_dow=0)`. Pass the same `WeekCycle` to every weekly component that must agree on the calendar: the `StepwiseTemporalProcess` above and any weekly `PopulationCounts` or `SubpopulationCounts` observation. `pyrenew.time` exports `MMWR_WEEK` (Sunday-Saturday epiweeks) and `ISO_WEEK` (Monday-Sunday); use `WeekCycle(start_dow=k)` for any other convention. At sample or run time, pass the date of the first observation day as `obs_start_date`. -The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it: - -```python -model.run( - num_warmup=500, - num_samples=1000, - rng_key=random.PRNGKey(42), - n_days_post_init=n_days_post_init, - population_size=population_size, - obs_start_date=date(2024, 1, 7), - -) -``` - -`obs_start_date` is required whenever any component performs calendar-aligned work: a latent temporal process with `alignment="calendar_week"`, a count observation with `aggregation="weekly"`, or any observation with a day-of-week effect. +Argument `obs_start_date` is required whenever any component performs calendar-aligned work: a latent temporal process with `alignment="calendar_week"`, a count observation with `aggregation="weekly"`, or any observation with a day-of-week effect. For `alignment="model_index"` and daily observations with no day-of-week effect, `obs_start_date` can be omitted. +The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it. ## Observation Processes @@ -715,6 +704,23 @@ samples = model.mcmc.get_samples() ``` where `**obs_data` is a data dictionary which supplies the observation start date as `obs_start_date` and a data dictionary for each set of signal data, where the name of the data dictionary corresponds to the signal name registered on the builder. +For this example, such a data dictionary would have the following structure: + +```python +{ + "obs_start_date": ... + "hospital": { + "obs": ... + }, + "wastewater": { + "obs": ... + "times": ... + "subpop_indices": ... + "sensor_indices": ... + "n_sensors": ... + }, +} +``` ## Running the Model From 469394ad119f2781eba53d6cb4d770fa078b94bb Mon Sep 17 00:00:00 2001 From: Mitzi Morris Date: Fri, 24 Apr 2026 18:26:52 -0400 Subject: [PATCH 17/17] changes per code review --- .../tutorials/building_multisignal_models.qmd | 1 - pyrenew/latent/base.py | 3 +- pyrenew/latent/temporal_processes.py | 3 +- pyrenew/model/pyrenew_builder.py | 68 ++----------------- pyrenew/time.py | 14 ++-- test/conftest.py | 36 ---------- test/test_pyrenew_builder.py | 39 ++++------- 7 files changed, 26 insertions(+), 138 deletions(-) diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index c1436d68..20d50bb7 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -270,7 +270,6 @@ weekly_baseline_rt_process = StepwiseTemporalProcess( Use `alignment="calendar_week"` when the weekly Rt blocks should align to a calendar week. PyRenew provides class `WeekCycle` which identifies the start day of a calendar week by ISO convention, e.g. `0 == Monday`, `6 == Sunday` and two pre-defined instances: `MMWR_WEEK: WeekCycle = WeekCycle(start_dow=6)` and `ISO_WEEK: WeekCycle = WeekCycle(start_dow=0)`. -Pass the same `WeekCycle` to every weekly component that must agree on the calendar: the `StepwiseTemporalProcess` above and any weekly `PopulationCounts` or `SubpopulationCounts` observation. `pyrenew.time` exports `MMWR_WEEK` (Sunday-Saturday epiweeks) and `ISO_WEEK` (Monday-Sunday); use `WeekCycle(start_dow=k)` for any other convention. At sample or run time, pass the date of the first observation day as `obs_start_date`. Argument `obs_start_date` is required whenever any component performs calendar-aligned work: a latent temporal process with `alignment="calendar_week"`, a count observation with `aggregation="weekly"`, or any observation with a day-of-week effect. diff --git a/pyrenew/latent/base.py b/pyrenew/latent/base.py index d09fd61c..1b5a1eda 100644 --- a/pyrenew/latent/base.py +++ b/pyrenew/latent/base.py @@ -12,6 +12,7 @@ from jax.typing import ArrayLike from numpyro.util import not_jax_tracer +from pyrenew.latent.temporal_processes import TemporalProcess from pyrenew.metaclass import RandomVariable @@ -355,8 +356,6 @@ def requires_calendar_anchor(self) -> bool: ``first_day_dow`` (derived from ``obs_start_date`` at the model entry point); ``False`` otherwise. """ - from pyrenew.latent.temporal_processes import TemporalProcess - for value in vars(self).values(): if ( isinstance(value, TemporalProcess) diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 0bd8054c..9b44fab2 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -41,8 +41,7 @@ Wraps [pyrenew.process.RandomWalk][]. - ``StepwiseTemporalProcess``: Wrapper that parameterizes any inner ``TemporalProcess`` at a coarser cadence and broadcasts to the per-timepoint - scale by repetition. Use to match R(t) parametrization to the coarsest - observation cadence. + scale by repetition. All implementations satisfy the ``TemporalProcess`` protocol and can be used interchangeably in hierarchical infection models. diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index 728db2a4..02fc1e3f 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -10,7 +10,6 @@ from typing import Any from pyrenew.latent.base import BaseLatentInfectionProcess -from pyrenew.latent.temporal_processes import TemporalProcess from pyrenew.model.multisignal_model import MultiSignalModel from pyrenew.observation.base import BaseObservationProcess @@ -196,70 +195,14 @@ def compute_n_initialization_points(self) -> int: return n_init - def _validate_coherence(self) -> None: - """ - Enforce calendar-anchor and structural coherence across components. - - Checks: - - - All weekly observations must share a single - :class:`pyrenew.time.WeekCycle`. - - Every temporal-process ``step_size`` must be a positive integer. - - Calendar-week-aligned temporal processes must share that - :class:`WeekCycle`. - - Raises - ------ - ValueError - If any of the above checks fails. - """ - weekly_weeks = { - obs.week - for obs in self.observations.values() - if getattr(obs, "aggregation", "daily") == "weekly" - } - if len(weekly_weeks) > 1: - raise ValueError( - f"Weekly observations must share a single WeekCycle; " - f"got {sorted(weekly_weeks, key=lambda w: w.start_dow)}" - ) - obs_week = next(iter(weekly_weeks), None) - - temporal_processes = { - name: value - for name, value in self.latent_params.items() - if isinstance(value, TemporalProcess) - } - - for param_name, process in temporal_processes.items(): - step_size = getattr(process, "step_size", 1) - if not isinstance(step_size, int) or step_size < 1: - raise ValueError( - f"Temporal process '{param_name}' must expose a positive " - f"integer step_size; got {step_size!r}" - ) - if ( - getattr(process, "alignment", None) == "calendar_week" - and obs_week is not None - ): - proc_week = getattr(process, "week", None) - if proc_week != obs_week: - raise ValueError( - f"Temporal process '{param_name}' has week={proc_week!r}, " - f"which disagrees with the weekly observation " - f"week={obs_week!r}" - ) - def build(self) -> MultiSignalModel: """ Build the multi-signal model with computed n_initialization_points. This method: - 1. Enforces coherence between R(t) cadence and observation cadences - 2. Computes n_initialization_points from all components - 3. Constructs the latent process with the computed value - 4. Creates a MultiSignalModel with automatic infection routing - 5. Validates that observation/latent types are compatible + 1. Computes n_initialization_points from all components + 2. Constructs the latent process with the computed value + 3. Creates a MultiSignalModel with automatic infection routing Can be called multiple times to create multiple model instances. @@ -271,14 +214,11 @@ def build(self) -> MultiSignalModel: Raises ------ ValueError - If latent process not configured, or if R(t) and observation - cadences are incoherent. + If latent process not configured. """ if self.latent_class is None: raise ValueError("Must call configure_latent() before build()") - self._validate_coherence() - # Compute n_initialization_points n_init = self.compute_n_initialization_points() diff --git a/pyrenew/time.py b/pyrenew/time.py index 80ee0627..309a9805 100644 --- a/pyrenew/time.py +++ b/pyrenew/time.py @@ -16,17 +16,15 @@ ``TemporalProcess.sample`` protocol. - :class:`WeekCycle` — a 7-day calendar cycle identified by its - ``start_dow``. Construction-time modeling choice, shared by every - model component that must agree on the same calendar week - (weekly :class:`CountObservation` and calendar-week-aligned - :class:`StepwiseTemporalProcess`). The module-level constants + ``start_dow``. Construction-time modeling choice, set independently on + each model component that does calendar-week work (weekly + :class:`CountObservation` and calendar-week-aligned + :class:`StepwiseTemporalProcess`). Different components can use + different cycles; each performs its own trimming and aggregation + relative to the shared daily axis. The module-level constants :data:`MMWR_WEEK` (Sun-Sat) and :data:`ISO_WEEK` (Mon-Sun) cover the common conventions; custom cycles use ``WeekCycle(start_dow=k)``. -``PyrenewBuilder._validate_coherence`` enforces that all weekly -observations share one ``WeekCycle`` and that any calendar-week-aligned -temporal process uses the same cycle. - Worked example: ``first_day_dow=3`` (Thursday), ``week=MMWR_WEEK`` (Sunday start), 17-day axis:: diff --git a/test/conftest.py b/test/conftest.py index 197f6d7d..e9100532 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -515,26 +515,6 @@ def sample(self, n_timepoints, n_processes=1, **kwargs): return jnp.zeros((n_timepoints, n_processes)) -class InvalidStepSizeTemporalProcess: - """Temporal process stub with invalid builder-inspected metadata.""" - - step_size = 0 - - def sample(self, **kwargs): - """ - Return an arbitrary array. - - Builder validation should reject this process before ``sample`` runs, - so the returned value is irrelevant. - - Returns - ------- - jnp.ndarray - Shape ``(1, 1)`` array of zeros. - """ - return jnp.zeros((1, 1)) - - @pytest.fixture def wrong_shape_temporal_process_cls(): """ @@ -560,19 +540,3 @@ def constant_temporal_process(): Instance whose ``sample`` returns zeros of the requested shape. """ return ConstantTemporalProcess() - - -@pytest.fixture -def invalid_step_size_temporal_process(): - """ - Temporal-process stub that advertises ``step_size=0``. - - Used to exercise ``PyrenewBuilder._validate_coherence`` rejection of - temporal processes whose metadata is structurally invalid. - - Returns - ------- - InvalidStepSizeTemporalProcess - Instance with ``step_size=0``. - """ - return InvalidStepSizeTemporalProcess() diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index c89c95ef..d2315530 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -663,8 +663,8 @@ def _daily_ed_counts(name="ed"): ) -class TestBuilderCoherence: - """PyrenewBuilder._validate_coherence enforcement at build() time.""" +class TestBuilderConfigurations: + """PyrenewBuilder.build() accepts varied R(t) and observation cadences.""" def test_daily_rt_with_daily_observation_passes(self): """step_size=1 and P=1: valid.""" @@ -706,8 +706,8 @@ def test_weekly_rt_with_daily_observation_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_mismatched_weekly_week_raises(self): - """Two weekly observations with different WeekCycles: rule 1 violation.""" + def test_mismatched_weekly_week_passes(self): + """Weekly observations can use different WeekCycles; each aggregates independently.""" builder = _coherence_builder( single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), observations=[ @@ -715,11 +715,11 @@ def test_mismatched_weekly_week_raises(self): _weekly_hosp_counts(name="other", week=WeekCycle(start_dow=0)), ], ) - with pytest.raises(ValueError, match="must share a single WeekCycle"): - builder.build() + model = builder.build() + assert isinstance(model, MultiSignalModel) def test_matching_weekly_week_passes(self): - """Two weekly observations sharing a WeekCycle: rule 1 passes.""" + """Two weekly observations sharing a WeekCycle build normally.""" builder = _coherence_builder( single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), observations=[ @@ -741,17 +741,6 @@ def test_arbitrary_step_size_with_weekly_observation_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_invalid_temporal_process_step_size_raises( - self, invalid_step_size_temporal_process - ): - """Temporal process metadata must expose a positive integer step_size.""" - builder = _coherence_builder( - single_rt_process=invalid_step_size_temporal_process, - observations=[_daily_ed_counts()], - ) - with pytest.raises(ValueError, match="positive integer step_size"): - builder.build() - def test_calendar_week_alignment_matches_weekly_week_passes(self): """Sunday-start weeks pair with Saturday-ending weekly observations.""" builder = _coherence_builder( @@ -766,8 +755,8 @@ def test_calendar_week_alignment_matches_weekly_week_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_calendar_week_alignment_mismatches_weekly_week_raises(self): - """A weekly Rt anchor must agree with the weekly observation anchor.""" + def test_calendar_week_alignment_mismatches_weekly_week_passes(self): + """A weekly Rt anchor can differ from a weekly observation anchor.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), @@ -777,11 +766,11 @@ def test_calendar_week_alignment_mismatches_weekly_week_raises(self): ), observations=[_weekly_hosp_counts(week=MMWR_WEEK)], ) - with pytest.raises(ValueError, match="disagrees with the weekly observation"): - builder.build() + model = builder.build() + assert isinstance(model, MultiSignalModel) def test_calendar_week_alignment_with_only_daily_observations_passes(self): - """No weekly observation means no calendar-anchor agreement to enforce.""" + """Calendar-week-aligned R(t) pairs with daily observations.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), @@ -794,8 +783,8 @@ def test_calendar_week_alignment_with_only_daily_observations_passes(self): model = builder.build() assert isinstance(model, MultiSignalModel) - def test_model_index_alignment_ignores_weekly_week(self): - """Model-index alignment carries no calendar anchor to compare against.""" + def test_model_index_alignment_with_weekly_observation_passes(self): + """Model-index-aligned R(t) pairs with weekly observations.""" builder = _coherence_builder( single_rt_process=StepwiseTemporalProcess( AR1(autoreg=0.9, innovation_sd=0.05), step_size=7