diff --git a/_typos.toml b/_typos.toml index 47865732..8cf4ff9c 100644 --- a/_typos.toml +++ b/_typos.toml @@ -3,3 +3,4 @@ # words that should not be corrected arange = "arange" lod = "lod" +dows = "dows" diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 23857e3a..20d50bb7 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -39,6 +39,7 @@ import plotnine as p9 import pandas as pd import time import warnings +from datetime import date warnings.filterwarnings("ignore") @@ -67,6 +68,7 @@ from pyrenew.latent import ( SubpopulationInfections, AR1, RandomWalk, + StepwiseTemporalProcess, GammaGroupSdPrior, HierarchicalNormalPrior, ) @@ -79,6 +81,7 @@ from pyrenew.observation import ( MeasurementNoise, NegativeBinomialNoise, ) +from pyrenew.time import MMWR_WEEK ``` ## Overview @@ -237,6 +240,42 @@ baseline_rt_process = AR1(autoreg=0.9, innovation_sd=0.05) subpop_rt_deviation_process = RandomWalk(innovation_sd=0.025) ``` +### Choosing the Rt Parameter Cadence + +The renewal equation is evaluated on the model's daily time axis, but the temporal process for $\mathcal{R}(t)$ does not have to sample a new parameter every day. +This separates three model choices: + +- **Parameter cadence**: how often the $\mathcal{R}(t)$ temporal process samples a new latent value +- **Model time axis**: the daily axis used by the renewal equation and delay convolutions +- **Observation cadence**: the temporal granularity for each signal, such as daily ED visits or weekly hospital admissions + +The AR(1) process above samples one value per model day: + +```python +baseline_rt_process = AR1(autoreg=0.9, innovation_sd=0.05) +``` + +To use a weekly $\mathcal{R}(t)$ while still running the renewal equation daily, wrap the temporal process in `StepwiseTemporalProcess`. +The wrapper samples a coarse trajectory and broadcasts it to the daily model axis before the latent infection process uses it. + +```python +weekly_baseline_rt_process = StepwiseTemporalProcess( + inner=AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, +) +``` + +Use `alignment="calendar_week"` when the weekly Rt blocks should align to a calendar week. +PyRenew provides class `WeekCycle` which identifies the start day of a calendar week by ISO convention, e.g. `0 == Monday`, `6 == Sunday` +and two pre-defined instances: `MMWR_WEEK: WeekCycle = WeekCycle(start_dow=6)` and `ISO_WEEK: WeekCycle = WeekCycle(start_dow=0)`. +`pyrenew.time` exports `MMWR_WEEK` (Sunday-Saturday epiweeks) and `ISO_WEEK` (Monday-Sunday); use `WeekCycle(start_dow=k)` for any other convention. +At sample or run time, pass the date of the first observation day as `obs_start_date`. +Argument `obs_start_date` is required whenever any component performs calendar-aligned work: a latent temporal process with `alignment="calendar_week"`, a count observation with `aggregation="weekly"`, or any observation with a day-of-week effect. +For `alignment="model_index"` and daily observations with no day-of-week effect, `obs_start_date` can be omitted. +The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it. + ## Observation Processes Observation processes transform latent infections into observable signals and define the statistical model linking predictions to data. Each observation process: @@ -268,7 +307,7 @@ for the first 10 months of 2023 (as reported to the CDC). # | label: load-hospital-data # Load daily hospital admissions for California ca_hosp_data = datasets.load_hospital_data_for_state("CA", "2023-11-06.csv") - +obs_start_date = ca_hosp_data["dates"][0] hosp_admits = ca_hosp_data["daily_admits"] population_size = ca_hosp_data["population"] n_hosp_days = ca_hosp_data["n_days"] @@ -658,18 +697,31 @@ model.run( n_days_post_init=n_days, population_size=population_size, subpop_fractions=subpop_fractions, - hospital={"obs": model.pad_observations(hosp_counts)}, - wastewater={ - "obs": ww_conc, - "times": model.shift_times(ww_times), - "subpop_indices": ww_subpop_indices, - "sensor_indices": ww_sensor_indices, - "n_sensors": n_ww_sensors, - }, + **obs_data, ) samples = model.mcmc.get_samples() ``` +where `**obs_data` is a data dictionary which supplies the observation start date as `obs_start_date` and a data dictionary for each set of signal data, where the name of the data dictionary corresponds to the signal name registered on the builder. +For this example, such a data dictionary would have the following structure: + +```python +{ + "obs_start_date": ... + "hospital": { + "obs": ... + }, + "wastewater": { + "obs": ... + "times": ... + "subpop_indices": ... + "sensor_indices": ... + "n_sensors": ... + }, +} +``` + + ## Running the Model First we declare the population structure. We have 6 subpopulations, where 5 have wastewater monitoring and 1 does not. The subpopulations with wastewater monitoring need not be contiguous indices—they could be any subset of {0, 1, ..., n_subpops-1}. @@ -699,11 +751,18 @@ print(f"Total population: {float(jnp.sum(subpop_fractions)):.0%}") ``` We define a function to prepare observation data using the model's helper methods to align with the shared time axis. +The returned dictionary is structured to match the keyword arguments of `model.run()`: `obs_start_date` at the top level, plus one sub-dictionary per observation process, keyed by the name registered on the builder (`hospital`, `wastewater`). +At call time the returned dict is unpacked with `**` (for example, `**obs_data_90` in the 90-day fit below), forwarding each entry as a keyword argument to `model.run()`. ```{python} # | label: prepare-observation-data def prepare_observation_data( - model, n_days_fit, hosp_admits, ww_data, ww_monitored_subpops + model, + n_days_fit, + obs_start_date, + hosp_admits, + ww_data, + ww_monitored_subpops, ): """ Prepare observation data for fitting. @@ -717,6 +776,8 @@ def prepare_observation_data( The model (provides padding/shifting helpers) n_days_fit : int Number of days to include in fit + obs_start_date : date + Date of first observation hosp_admits : array Hospital admissions time series ww_data : dict @@ -749,6 +810,7 @@ def prepare_observation_data( ) return { + "obs_start_date": obs_start_date, "hospital": { "obs": hosp_obs, }, @@ -774,7 +836,12 @@ jax.clear_caches() n_days_90 = 90 obs_data_90 = prepare_observation_data( - model, n_days_90, hosp_admits, ca_ww_data, ww_monitored_subpops + model, + n_days_90, + obs_start_date, + hosp_admits, + ca_ww_data, + ww_monitored_subpops, ) print(f"Fitting model with {n_days_90} days of data...") @@ -939,7 +1006,12 @@ jax.clear_caches() n_days_180 = 180 obs_data_180 = prepare_observation_data( - model, n_days_180, hosp_admits, ca_ww_data, ww_monitored_subpops + model, + n_days_180, + obs_start_date, + hosp_admits, + ca_ww_data, + ww_monitored_subpops, ) print(f"Fitting model with {n_days_180} days of data...") diff --git a/docs/tutorials/day_of_week_effects.qmd b/docs/tutorials/day_of_week_effects.qmd index ccf4058f..147bfc8c 100644 --- a/docs/tutorials/day_of_week_effects.qmd +++ b/docs/tutorials/day_of_week_effects.qmd @@ -257,8 +257,8 @@ offset_df["offset"] = pd.Categorical( The two curves have the same shape but are phase-shifted: their weekend dips fall on different days. Getting `first_day_dow` right matters — a misaligned offset would attribute Monday's high to Sunday or vice versa. -When using `MultiSignalModel`, the shared time axis starts `n_init` days before the first observation. -The convenience method `model.compute_first_day_dow(obs_start_dow)` converts the known day of the week of the first observation to the correct offset for element 0 of the time axis. +When using `MultiSignalModel`, pass `obs_start_date` (the date of the first observation day) to `model.sample()` or `model.run()`. +The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it. ## Sampled observations diff --git a/pyproject.toml b/pyproject.toml index 3f84cf70..e4465682 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,11 @@ known_first_party = ["pyrenew", "test"] [tool.deptry.per_rule_ignores] DEP004 = ["arviz", "pytest", "scipy", "bs4"] +[tool.pytest.ini_options] +markers = [ + "integration: integration tests that fit models via MCMC (deselect with '-m \"not integration\"')", +] + [tool.ruff] fix = true diff --git a/pyrenew/arrayutils.py b/pyrenew/arrayutils.py index 64750d42..973466a4 100644 --- a/pyrenew/arrayutils.py +++ b/pyrenew/arrayutils.py @@ -24,6 +24,28 @@ def __repr__(self) -> str: return f"PeriodicProcessSample(value={self})" +def require_shape(arr: ArrayLike, expected: tuple[int, ...], label: str) -> None: + """ + Validate that ``arr.shape`` equals ``expected``. + + Parameters + ---------- + arr + Array whose shape is being checked. + expected + Required shape. + label + Name of the producing component, used in the error message. + + Raises + ------ + ValueError + If ``arr.shape`` does not equal ``expected``. + """ + if arr.shape != expected: + raise ValueError(f"{label} must return shape {expected}; got {arr.shape}") + + def tile_until_n( data: ArrayLike, n_timepoints: int, diff --git a/pyrenew/datasets/datagen_he_CA_120.py b/pyrenew/datasets/datagen_he_CA_126.py similarity index 81% rename from pyrenew/datasets/datagen_he_CA_120.py rename to pyrenew/datasets/datagen_he_CA_126.py index 3ccfeaf9..c1358060 100644 --- a/pyrenew/datasets/datagen_he_CA_120.py +++ b/pyrenew/datasets/datagen_he_CA_126.py @@ -1,13 +1,18 @@ # numpydoc ignore=ES01,SA01,EX01 """ -Generate 120-day synthetic CA hospital + ED visit data from known R(t). +Generate 126-day synthetic CA hospital + ED visit data from known R(t). This script defines R(t) directly, runs a renewal equation forward, and convolves with signal-specific delay PMFs to produce two observation streams. All true parameters are saved alongside the synthetic observations so tutorials can demonstrate posterior recovery. -Outputs (in pyrenew/datasets/synthetic_CA_120/): +The start date (2023-11-05, Sunday) and length (126 days = 18 weeks) +are chosen so that the daily time axis aligns exactly with 18 complete +MMWR epiweeks, with no days trimmed at either end during weekly +aggregation. + +Outputs (in pyrenew/datasets/synthetic_CA_126/): - true_parameters.json - daily_infections.csv - daily_ed_visits.csv @@ -16,7 +21,7 @@ Run from repo root:: - python -m pyrenew.datasets.datagen_he_CA_120 + python -m pyrenew.datasets.datagen_he_CA_126 """ from __future__ import annotations @@ -33,7 +38,7 @@ from pyrenew.time import daily_to_mmwr_epiweekly, get_sequential_day_of_week_indices REPO_ROOT = Path(__file__).resolve().parents[2] -OUTPUT_DIR = REPO_ROOT / "pyrenew" / "datasets" / "synthetic_CA_120" +OUTPUT_DIR = REPO_ROOT / "pyrenew" / "datasets" / "synthetic_CA_126" RNG_SEED = 20240101 POPULATION = 39_512_223 @@ -123,10 +128,11 @@ IEDR = 0.0075 I0_PER_CAPITA = 5e-4 NEGBINOM_CONCENTRATION_HOSP = 350.0 +NEGBINOM_CONCENTRATION_HOSP_WEEKLY = 100.0 NEGBINOM_CONCENTRATION_ED = 50.0 DOW_EFFECTS = np.array([1.15, 1.12, 1.08, 1.05, 0.98, 0.82, 0.80]) -START_DATE = date(2023, 11, 6) +START_DATE = date(2023, 11, 5) N_INIT = max(50, len(HOSP_DELAY_PMF), len(ED_DELAY_PMF)) @@ -134,8 +140,8 @@ def build_true_rt() -> np.ndarray: """ Build a piecewise-linear true R(t) trajectory. - Phases: decline from 1.2 to 0.8 (60 d), rise from 0.8 to 1.1 (40 d), - decline from 1.1 to 0.85 (20 d). + Phases: decline from 1.2 to 0.8 (63 d), rise from 0.8 to 1.1 (42 d), + decline from 1.1 to 0.85 (21 d). Total 126 days = 18 weeks. Returns ------- @@ -143,9 +149,9 @@ def build_true_rt() -> np.ndarray: R(t) trajectory. """ segments = [ - (60, 1.2, 0.8), - (40, 0.8, 1.1), - (20, 1.1, 0.85), + (63, 1.2, 0.8), + (42, 0.8, 1.1), + (21, 1.1, 0.85), ] rt = np.concatenate( [ @@ -204,7 +210,7 @@ def run_renewal( def apply_day_of_week_effects( - values: np.ndarray, dow_effects: np.ndarray, first_dow: int + values: np.ndarray, dow_effects: np.ndarray, first_day_dow: int ) -> np.ndarray: """ Apply multiplicative day-of-week effects to daily values. @@ -216,7 +222,7 @@ def apply_day_of_week_effects( dow_effects : np.ndarray Multiplicative effects for each day (length 7, ISO convention: 0 = Monday, 6 = Sunday). Should sum to 7 to preserve weekly totals. - first_dow : int + first_day_dow : int ISO day-of-week of the first element in values. Returns @@ -224,7 +230,7 @@ def apply_day_of_week_effects( np.ndarray Adjusted daily values. """ - day_indices = get_sequential_day_of_week_indices(first_dow, len(values)) + day_indices = get_sequential_day_of_week_indices(first_day_dow, len(values)) return values * dow_effects[day_indices] @@ -255,32 +261,29 @@ def sample_negbinom( return rng.negative_binomial(n=concentration, p=p) -def aggregate_to_epiweeks( - daily_values: np.ndarray, +def build_weekly_hosp_frame( + weekly_values: np.ndarray, start_date: date, ) -> pl.DataFrame: """ - Aggregate daily counts to MMWR epiweek totals (Sun-Sat, labeled by Saturday). - - Only complete 7-day weeks are kept. + Build a weekly hospital admissions frame with MMWR week-ending dates. Parameters ---------- - daily_values : np.ndarray - Daily count time series. + weekly_values : np.ndarray + One value per MMWR epiweek (aggregated by the caller). start_date : date - Date of the first element. + Date of day 0 of the daily time series the weekly values were + aggregated from. Used to compute the date of the first + Saturday (week-ending day). Returns ------- pl.DataFrame Columns: week_end, weekly_hosp_admits. """ - first_dow = start_date.weekday() - weekly_values = daily_to_mmwr_epiweekly( - np.asarray(daily_values), input_data_first_dow=first_dow - ) - days_to_first_sunday = (6 - first_dow) % 7 + first_day_dow = start_date.weekday() + days_to_first_sunday = (6 - first_day_dow) % 7 first_week_end = start_date + timedelta(days=days_to_first_sunday + 6) n_weeks = len(weekly_values) week_ends = [first_week_end + timedelta(weeks=i) for i in range(n_weeks)] @@ -305,7 +308,7 @@ def generate() -> None: infections_obs = infections_full[N_INIT:] obs_dates = [START_DATE + timedelta(days=i) for i in range(n_days)] - first_dow = START_DATE.weekday() + first_day_dow = START_DATE.weekday() expected_hosp_daily, _ = compute_delay_ascertained_incidence( latent_incidence=infections_full, @@ -319,7 +322,18 @@ def generate() -> None: expected_hosp_daily, NEGBINOM_CONCENTRATION_HOSP, rng ) - weekly_hosp = aggregate_to_epiweeks(hosp_daily_obs, START_DATE) + # Weekly hospital admissions: aggregate daily *expected* values to MMWR + # epiweeks, then apply one NegBin at the weekly scale. This matches the + # generative assumption of PopulationCounts(aggregation_period=7) with a + # single NegativeBinomialNoise at the reporting cadence. + expected_hosp_weekly = np.asarray( + daily_to_mmwr_epiweekly(expected_hosp_daily, input_data_first_dow=first_day_dow) + ) + expected_hosp_weekly = np.maximum(expected_hosp_weekly, 1.0) + hosp_weekly_obs = sample_negbinom( + expected_hosp_weekly, NEGBINOM_CONCENTRATION_HOSP_WEEKLY, rng + ) + weekly_hosp = build_weekly_hosp_frame(hosp_weekly_obs, START_DATE) hosp_daily_df = pl.DataFrame( { @@ -338,7 +352,7 @@ def generate() -> None: pad=True, ) expected_ed = expected_ed[N_INIT:] - expected_ed = apply_day_of_week_effects(expected_ed, DOW_EFFECTS, first_dow) + expected_ed = apply_day_of_week_effects(expected_ed, DOW_EFFECTS, first_day_dow) expected_ed = np.maximum(expected_ed, 1.0) ed_obs = sample_negbinom(expected_ed, NEGBINOM_CONCENTRATION_ED, rng) @@ -380,14 +394,15 @@ def generate() -> None: "generation_interval_pmf": GEN_INT_PMF.tolist(), "i0_per_capita": I0_PER_CAPITA, "rt_trajectory": { - "phase_1": {"days": 60, "start": 1.2, "end": 0.8}, - "phase_2": {"days": 40, "start": 0.8, "end": 1.1}, - "phase_3": {"days": 20, "start": 1.1, "end": 0.85}, + "phase_1": {"days": 63, "start": 1.2, "end": 0.8}, + "phase_2": {"days": 42, "start": 0.8, "end": 1.1}, + "phase_3": {"days": 21, "start": 1.1, "end": 0.85}, }, "hospitalizations": { "ihr": IHR, "delay_pmf_source": "infection_admission_interval.tsv", - "negbinom_concentration": NEGBINOM_CONCENTRATION_HOSP, + "negbinom_concentration_daily": NEGBINOM_CONCENTRATION_HOSP, + "negbinom_concentration_weekly": NEGBINOM_CONCENTRATION_HOSP_WEEKLY, "temporal_resolutions": ["daily", "weekly_epiweek"], }, "ed_visits": { diff --git a/pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv b/pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv deleted file mode 100644 index 8bfdc09a..00000000 --- a/pyrenew/datasets/synthetic_CA_120/daily_ed_visits.csv +++ /dev/null @@ -1,121 +0,0 @@ -date,geo_value,disease,ed_visits -2023-11-06,CA,COVID-19,144 -2023-11-07,CA,COVID-19,120 -2023-11-08,CA,COVID-19,151 -2023-11-09,CA,COVID-19,154 -2023-11-10,CA,COVID-19,134 -2023-11-11,CA,COVID-19,133 -2023-11-12,CA,COVID-19,158 -2023-11-13,CA,COVID-19,189 -2023-11-14,CA,COVID-19,208 -2023-11-15,CA,COVID-19,307 -2023-11-16,CA,COVID-19,259 -2023-11-17,CA,COVID-19,350 -2023-11-18,CA,COVID-19,369 -2023-11-19,CA,COVID-19,381 -2023-11-20,CA,COVID-19,401 -2023-11-21,CA,COVID-19,629 -2023-11-22,CA,COVID-19,449 -2023-11-23,CA,COVID-19,551 -2023-11-24,CA,COVID-19,422 -2023-11-25,CA,COVID-19,409 -2023-11-26,CA,COVID-19,349 -2023-11-27,CA,COVID-19,593 -2023-11-28,CA,COVID-19,740 -2023-11-29,CA,COVID-19,700 -2023-11-30,CA,COVID-19,767 -2023-12-01,CA,COVID-19,703 -2023-12-02,CA,COVID-19,688 -2023-12-03,CA,COVID-19,604 -2023-12-04,CA,COVID-19,1018 -2023-12-05,CA,COVID-19,814 -2023-12-06,CA,COVID-19,851 -2023-12-07,CA,COVID-19,890 -2023-12-08,CA,COVID-19,644 -2023-12-09,CA,COVID-19,711 -2023-12-10,CA,COVID-19,614 -2023-12-11,CA,COVID-19,1092 -2023-12-12,CA,COVID-19,815 -2023-12-13,CA,COVID-19,539 -2023-12-14,CA,COVID-19,729 -2023-12-15,CA,COVID-19,711 -2023-12-16,CA,COVID-19,638 -2023-12-17,CA,COVID-19,502 -2023-12-18,CA,COVID-19,894 -2023-12-19,CA,COVID-19,716 -2023-12-20,CA,COVID-19,599 -2023-12-21,CA,COVID-19,536 -2023-12-22,CA,COVID-19,605 -2023-12-23,CA,COVID-19,354 -2023-12-24,CA,COVID-19,620 -2023-12-25,CA,COVID-19,708 -2023-12-26,CA,COVID-19,485 -2023-12-27,CA,COVID-19,528 -2023-12-28,CA,COVID-19,463 -2023-12-29,CA,COVID-19,322 -2023-12-30,CA,COVID-19,166 -2023-12-31,CA,COVID-19,178 -2024-01-01,CA,COVID-19,296 -2024-01-02,CA,COVID-19,299 -2024-01-03,CA,COVID-19,219 -2024-01-04,CA,COVID-19,188 -2024-01-05,CA,COVID-19,147 -2024-01-06,CA,COVID-19,68 -2024-01-07,CA,COVID-19,98 -2024-01-08,CA,COVID-19,113 -2024-01-09,CA,COVID-19,117 -2024-01-10,CA,COVID-19,74 -2024-01-11,CA,COVID-19,69 -2024-01-12,CA,COVID-19,84 -2024-01-13,CA,COVID-19,58 -2024-01-14,CA,COVID-19,26 -2024-01-15,CA,COVID-19,42 -2024-01-16,CA,COVID-19,53 -2024-01-17,CA,COVID-19,62 -2024-01-18,CA,COVID-19,42 -2024-01-19,CA,COVID-19,47 -2024-01-20,CA,COVID-19,32 -2024-01-21,CA,COVID-19,29 -2024-01-22,CA,COVID-19,19 -2024-01-23,CA,COVID-19,21 -2024-01-24,CA,COVID-19,28 -2024-01-25,CA,COVID-19,20 -2024-01-26,CA,COVID-19,9 -2024-01-27,CA,COVID-19,16 -2024-01-28,CA,COVID-19,15 -2024-01-29,CA,COVID-19,34 -2024-01-30,CA,COVID-19,30 -2024-01-31,CA,COVID-19,21 -2024-02-01,CA,COVID-19,19 -2024-02-02,CA,COVID-19,18 -2024-02-03,CA,COVID-19,24 -2024-02-04,CA,COVID-19,10 -2024-02-05,CA,COVID-19,25 -2024-02-06,CA,COVID-19,16 -2024-02-07,CA,COVID-19,27 -2024-02-08,CA,COVID-19,27 -2024-02-09,CA,COVID-19,29 -2024-02-10,CA,COVID-19,21 -2024-02-11,CA,COVID-19,17 -2024-02-12,CA,COVID-19,26 -2024-02-13,CA,COVID-19,25 -2024-02-14,CA,COVID-19,27 -2024-02-15,CA,COVID-19,24 -2024-02-16,CA,COVID-19,20 -2024-02-17,CA,COVID-19,21 -2024-02-18,CA,COVID-19,38 -2024-02-19,CA,COVID-19,41 -2024-02-20,CA,COVID-19,38 -2024-02-21,CA,COVID-19,42 -2024-02-22,CA,COVID-19,41 -2024-02-23,CA,COVID-19,32 -2024-02-24,CA,COVID-19,32 -2024-02-25,CA,COVID-19,27 -2024-02-26,CA,COVID-19,54 -2024-02-27,CA,COVID-19,45 -2024-02-28,CA,COVID-19,41 -2024-02-29,CA,COVID-19,26 -2024-03-01,CA,COVID-19,23 -2024-03-02,CA,COVID-19,19 -2024-03-03,CA,COVID-19,18 -2024-03-04,CA,COVID-19,24 diff --git a/pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv deleted file mode 100644 index fd32f156..00000000 --- a/pyrenew/datasets/synthetic_CA_120/daily_hospital_admissions.csv +++ /dev/null @@ -1,121 +0,0 @@ -date,geo_value,daily_hosp_admits,pop -2023-11-06,CA,28,39512223 -2023-11-07,CA,31,39512223 -2023-11-08,CA,43,39512223 -2023-11-09,CA,30,39512223 -2023-11-10,CA,57,39512223 -2023-11-11,CA,47,39512223 -2023-11-12,CA,66,39512223 -2023-11-13,CA,94,39512223 -2023-11-14,CA,66,39512223 -2023-11-15,CA,81,39512223 -2023-11-16,CA,93,39512223 -2023-11-17,CA,121,39512223 -2023-11-18,CA,146,39512223 -2023-11-19,CA,118,39512223 -2023-11-20,CA,175,39512223 -2023-11-21,CA,208,39512223 -2023-11-22,CA,197,39512223 -2023-11-23,CA,223,39512223 -2023-11-24,CA,226,39512223 -2023-11-25,CA,257,39512223 -2023-11-26,CA,276,39512223 -2023-11-27,CA,276,39512223 -2023-11-28,CA,300,39512223 -2023-11-29,CA,321,39512223 -2023-11-30,CA,328,39512223 -2023-12-01,CA,351,39512223 -2023-12-02,CA,354,39512223 -2023-12-03,CA,390,39512223 -2023-12-04,CA,441,39512223 -2023-12-05,CA,415,39512223 -2023-12-06,CA,381,39512223 -2023-12-07,CA,413,39512223 -2023-12-08,CA,425,39512223 -2023-12-09,CA,418,39512223 -2023-12-10,CA,478,39512223 -2023-12-11,CA,491,39512223 -2023-12-12,CA,514,39512223 -2023-12-13,CA,489,39512223 -2023-12-14,CA,423,39512223 -2023-12-15,CA,509,39512223 -2023-12-16,CA,560,39512223 -2023-12-17,CA,457,39512223 -2023-12-18,CA,493,39512223 -2023-12-19,CA,537,39512223 -2023-12-20,CA,505,39512223 -2023-12-21,CA,493,39512223 -2023-12-22,CA,430,39512223 -2023-12-23,CA,495,39512223 -2023-12-24,CA,468,39512223 -2023-12-25,CA,423,39512223 -2023-12-26,CA,419,39512223 -2023-12-27,CA,377,39512223 -2023-12-28,CA,356,39512223 -2023-12-29,CA,395,39512223 -2023-12-30,CA,390,39512223 -2023-12-31,CA,295,39512223 -2024-01-01,CA,283,39512223 -2024-01-02,CA,333,39512223 -2024-01-03,CA,244,39512223 -2024-01-04,CA,271,39512223 -2024-01-05,CA,235,39512223 -2024-01-06,CA,237,39512223 -2024-01-07,CA,228,39512223 -2024-01-08,CA,221,39512223 -2024-01-09,CA,151,39512223 -2024-01-10,CA,168,39512223 -2024-01-11,CA,166,39512223 -2024-01-12,CA,133,39512223 -2024-01-13,CA,125,39512223 -2024-01-14,CA,105,39512223 -2024-01-15,CA,105,39512223 -2024-01-16,CA,95,39512223 -2024-01-17,CA,80,39512223 -2024-01-18,CA,79,39512223 -2024-01-19,CA,57,39512223 -2024-01-20,CA,60,39512223 -2024-01-21,CA,51,39512223 -2024-01-22,CA,49,39512223 -2024-01-23,CA,30,39512223 -2024-01-24,CA,43,39512223 -2024-01-25,CA,38,39512223 -2024-01-26,CA,39,39512223 -2024-01-27,CA,35,39512223 -2024-01-28,CA,25,39512223 -2024-01-29,CA,19,39512223 -2024-01-30,CA,20,39512223 -2024-01-31,CA,15,39512223 -2024-02-01,CA,20,39512223 -2024-02-02,CA,15,39512223 -2024-02-03,CA,15,39512223 -2024-02-04,CA,9,39512223 -2024-02-05,CA,12,39512223 -2024-02-06,CA,9,39512223 -2024-02-07,CA,16,39512223 -2024-02-08,CA,19,39512223 -2024-02-09,CA,22,39512223 -2024-02-10,CA,13,39512223 -2024-02-11,CA,15,39512223 -2024-02-12,CA,15,39512223 -2024-02-13,CA,13,39512223 -2024-02-14,CA,12,39512223 -2024-02-15,CA,13,39512223 -2024-02-16,CA,22,39512223 -2024-02-17,CA,19,39512223 -2024-02-18,CA,20,39512223 -2024-02-19,CA,15,39512223 -2024-02-20,CA,15,39512223 -2024-02-21,CA,15,39512223 -2024-02-22,CA,19,39512223 -2024-02-23,CA,16,39512223 -2024-02-24,CA,15,39512223 -2024-02-25,CA,23,39512223 -2024-02-26,CA,20,39512223 -2024-02-27,CA,19,39512223 -2024-02-28,CA,22,39512223 -2024-02-29,CA,22,39512223 -2024-03-01,CA,28,39512223 -2024-03-02,CA,21,39512223 -2024-03-03,CA,20,39512223 -2024-03-04,CA,27,39512223 diff --git a/pyrenew/datasets/synthetic_CA_120/daily_infections.csv b/pyrenew/datasets/synthetic_CA_120/daily_infections.csv deleted file mode 100644 index d22e74f1..00000000 --- a/pyrenew/datasets/synthetic_CA_120/daily_infections.csv +++ /dev/null @@ -1,121 +0,0 @@ -date,true_infections,true_rt -2023-11-06,19756.11054520949,1.2 -2023-11-07,22137.832061728463,1.1933333333333334 -2023-11-08,24713.05529723112,1.1866666666666665 -2023-11-09,27483.293945028956,1.18 -2023-11-10,30447.758878007327,1.1733333333333333 -2023-11-11,33603.09179301876,1.1666666666666665 -2023-11-12,36943.12084926061,1.16 -2023-11-13,40458.647149085045,1.1533333333333333 -2023-11-14,44137.271061809806,1.1466666666666667 -2023-11-15,47963.348602612365,1.14 -2023-11-16,51917.84793339851,1.1333333333333333 -2023-11-17,55978.37022993391,1.1266666666666667 -2023-11-18,60119.22285114966,1.1199999999999999 -2023-11-19,64311.56901533906,1.1133333333333333 -2023-11-20,68523.65715311568,1.1066666666666667 -2023-11-21,72721.13103142017,1.1 -2023-11-22,76867.41944068302,1.0933333333333333 -2023-11-23,80924.2017246116,1.0866666666666667 -2023-11-24,84851.94303922662,1.08 -2023-11-25,88610.49054041045,1.0733333333333333 -2023-11-26,92159.71934937873,1.0666666666666667 -2023-11-27,95460.2149212564,1.06 -2023-11-28,98473.97653721012,1.0533333333333332 -2023-11-29,101165.12513683562,1.0466666666666666 -2023-11-30,103500.59768685527,1.04 -2023-12-01,105450.80981140169,1.0333333333333332 -2023-12-02,106990.26853536024,1.0266666666666666 -2023-12-03,108098.11774003913,1.02 -2023-12-04,108758.6002989673,1.0133333333333334 -2023-12-05,108961.42282384375,1.0066666666666666 -2023-12-06,108702.01145320752,1.0 -2023-12-07,107981.65008103396,0.9933333333333333 -2023-12-08,106807.49574885165,0.9866666666666667 -2023-12-09,105192.46949463306,0.98 -2023-12-10,103155.02463303752,0.9733333333333334 -2023-12-11,100718.79809568585,0.9666666666666667 -2023-12-12,97912.15394717995,0.96 -2023-12-13,94767.63137826548,0.9533333333333334 -2023-12-14,91321.31223948221,0.9466666666666667 -2023-12-15,87612.1254121439,0.94 -2023-12-16,83681.10693644159,0.9333333333333333 -2023-12-17,79570.63577318197,0.9266666666666667 -2023-12-18,75323.66533931163,0.9199999999999999 -2023-12-19,70982.97053090445,0.9133333333333333 -2023-12-20,66590.42886290187,0.9066666666666667 -2023-12-21,62186.35267198246,0.9 -2023-12-22,57808.88713067756,0.8933333333333333 -2023-12-23,53493.48620967932,0.8866666666666667 -2023-12-24,49272.47581752002,0.88 -2023-12-25,45174.71026669215,0.8733333333333333 -2023-12-26,41225.32508890801,0.8666666666666667 -2023-12-27,37445.58617150747,0.8600000000000001 -2023-12-28,33852.8323243111,0.8533333333333333 -2023-12-29,30460.50580940868,0.8466666666666667 -2023-12-30,27278.263155260243,0.8400000000000001 -2023-12-31,24312.156790105855,0.8333333333333334 -2024-01-01,21564.876704915074,0.8266666666666667 -2024-01-02,19036.040507503185,0.8200000000000001 -2024-01-03,16722.519850180324,0.8133333333333334 -2024-01-04,14618.791277181628,0.8066666666666666 -2024-01-05,12717.300002393937,0.8 -2024-01-06,11205.412114101426,0.8075 -2024-01-07,9925.754783687535,0.8150000000000001 -2024-01-08,8838.84684474518,0.8225 -2024-01-09,7912.576007747334,0.8300000000000001 -2024-01-10,7120.727700272463,0.8375 -2024-01-11,6441.825882807316,0.8450000000000001 -2024-01-12,5858.216338610594,0.8525 -2024-01-13,5355.339128127326,0.8600000000000001 -2024-01-14,4921.050758957707,0.8675 -2024-01-15,4545.338518068657,0.875 -2024-01-16,4219.883804499359,0.8825000000000001 -2024-01-17,3937.758638554082,0.8900000000000001 -2024-01-18,3693.1812564367424,0.8975000000000001 -2024-01-19,3481.3187502359347,0.905 -2024-01-20,3298.127311750783,0.9125000000000001 -2024-01-21,3140.222655134437,0.92 -2024-01-22,3004.7748166431147,0.9275000000000001 -2024-01-23,2889.422557099414,0.935 -2024-01-24,2792.2038372695265,0.9425000000000001 -2024-01-25,2711.4994312736244,0.9500000000000001 -2024-01-26,2645.9873791490927,0.9575 -2024-01-27,2594.6064611590755,0.9650000000000001 -2024-01-28,2556.5272634637226,0.9725000000000001 -2024-01-29,2531.1297188751973,0.9800000000000001 -2024-01-30,2517.9862644961117,0.9875 -2024-01-31,2516.8499736912167,0.9950000000000001 -2024-02-01,2527.647203642624,1.0025000000000002 -2024-02-02,2550.4744608668757,1.01 -2024-02-03,2585.599332958829,1.0175 -2024-02-04,2633.4654720916783,1.0250000000000001 -2024-02-05,2694.7017504746477,1.0325000000000002 -2024-02-06,2770.13584583844,1.04 -2024-02-07,2860.8126619369277,1.0475 -2024-02-08,2968.0181512451327,1.0550000000000002 -2024-02-09,3093.3092914120125,1.0625 -2024-02-10,3238.551181553564,1.07 -2024-02-11,3405.962478571058,1.0775000000000001 -2024-02-12,3598.1706987346743,1.085 -2024-02-13,3818.2792797395227,1.0925 -2024-02-14,4069.948750608039,1.1 -2024-02-15,4278.804259714245,1.0875000000000001 -2024-02-16,4464.378292782625,1.0750000000000002 -2024-02-17,4622.498189584832,1.0625 -2024-02-18,4749.429210576598,1.05 -2024-02-19,4842.016091859674,1.0375 -2024-02-20,4897.808898382193,1.0250000000000001 -2024-02-21,4915.166344392091,1.0125 -2024-02-22,4893.330641693712,1.0 -2024-02-23,4832.51409477793,0.9875 -2024-02-24,4733.84928467361,0.9750000000000001 -2024-02-25,4599.3707190434,0.9625 -2024-02-26,4431.94639781228,0.95 -2024-02-27,4235.176372034393,0.9375 -2024-02-28,4013.263271199615,0.925 -2024-02-29,3770.8613226904004,0.9125 -2024-03-01,3512.911514102406,0.9 -2024-03-02,3244.47115210662,0.8875 -2024-03-03,2970.5462345139176,0.875 -2024-03-04,2695.934508618182,0.8625 diff --git a/pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv deleted file mode 100644 index 0b1d3b8f..00000000 --- a/pyrenew/datasets/synthetic_CA_120/weekly_hospital_admissions.csv +++ /dev/null @@ -1,17 +0,0 @@ -week_end,weekly_hosp_admits,location,pop -2023-11-18,667,CA,39512223 -2023-11-25,1404,CA,39512223 -2023-12-02,2206,CA,39512223 -2023-12-09,2883,CA,39512223 -2023-12-16,3464,CA,39512223 -2023-12-23,3410,CA,39512223 -2023-12-30,2828,CA,39512223 -2024-01-06,1898,CA,39512223 -2024-01-13,1192,CA,39512223 -2024-01-20,581,CA,39512223 -2024-01-27,285,CA,39512223 -2024-02-03,129,CA,39512223 -2024-02-10,100,CA,39512223 -2024-02-17,109,CA,39512223 -2024-02-24,115,CA,39512223 -2024-03-02,155,CA,39512223 diff --git a/pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv b/pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv new file mode 100644 index 00000000..3f4c4651 --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/daily_ed_visits.csv @@ -0,0 +1,127 @@ +date,geo_value,disease,ed_visits +2023-11-05,CA,COVID-19,95 +2023-11-06,CA,COVID-19,125 +2023-11-07,CA,COVID-19,138 +2023-11-08,CA,COVID-19,136 +2023-11-09,CA,COVID-19,172 +2023-11-10,CA,COVID-19,161 +2023-11-11,CA,COVID-19,155 +2023-11-12,CA,COVID-19,133 +2023-11-13,CA,COVID-19,273 +2023-11-14,CA,COVID-19,263 +2023-11-15,CA,COVID-19,360 +2023-11-16,CA,COVID-19,296 +2023-11-17,CA,COVID-19,202 +2023-11-18,CA,COVID-19,272 +2023-11-19,CA,COVID-19,312 +2023-11-20,CA,COVID-19,510 +2023-11-21,CA,COVID-19,448 +2023-11-22,CA,COVID-19,591 +2023-11-23,CA,COVID-19,529 +2023-11-24,CA,COVID-19,476 +2023-11-25,CA,COVID-19,403 +2023-11-26,CA,COVID-19,536 +2023-11-27,CA,COVID-19,603 +2023-11-28,CA,COVID-19,1160 +2023-11-29,CA,COVID-19,989 +2023-11-30,CA,COVID-19,769 +2023-12-01,CA,COVID-19,901 +2023-12-02,CA,COVID-19,750 +2023-12-03,CA,COVID-19,617 +2023-12-04,CA,COVID-19,651 +2023-12-05,CA,COVID-19,729 +2023-12-06,CA,COVID-19,885 +2023-12-07,CA,COVID-19,1026 +2023-12-08,CA,COVID-19,792 +2023-12-09,CA,COVID-19,675 +2023-12-10,CA,COVID-19,772 +2023-12-11,CA,COVID-19,896 +2023-12-12,CA,COVID-19,623 +2023-12-13,CA,COVID-19,786 +2023-12-14,CA,COVID-19,777 +2023-12-15,CA,COVID-19,904 +2023-12-16,CA,COVID-19,561 +2023-12-17,CA,COVID-19,531 +2023-12-18,CA,COVID-19,983 +2023-12-19,CA,COVID-19,973 +2023-12-20,CA,COVID-19,583 +2023-12-21,CA,COVID-19,504 +2023-12-22,CA,COVID-19,660 +2023-12-23,CA,COVID-19,647 +2023-12-24,CA,COVID-19,490 +2023-12-25,CA,COVID-19,721 +2023-12-26,CA,COVID-19,659 +2023-12-27,CA,COVID-19,479 +2023-12-28,CA,COVID-19,352 +2023-12-29,CA,COVID-19,348 +2023-12-30,CA,COVID-19,320 +2023-12-31,CA,COVID-19,229 +2024-01-01,CA,COVID-19,269 +2024-01-02,CA,COVID-19,309 +2024-01-03,CA,COVID-19,326 +2024-01-04,CA,COVID-19,277 +2024-01-05,CA,COVID-19,235 +2024-01-06,CA,COVID-19,167 +2024-01-07,CA,COVID-19,116 +2024-01-08,CA,COVID-19,168 +2024-01-09,CA,COVID-19,129 +2024-01-10,CA,COVID-19,165 +2024-01-11,CA,COVID-19,91 +2024-01-12,CA,COVID-19,91 +2024-01-13,CA,COVID-19,54 +2024-01-14,CA,COVID-19,68 +2024-01-15,CA,COVID-19,88 +2024-01-16,CA,COVID-19,82 +2024-01-17,CA,COVID-19,64 +2024-01-18,CA,COVID-19,47 +2024-01-19,CA,COVID-19,40 +2024-01-20,CA,COVID-19,30 +2024-01-21,CA,COVID-19,28 +2024-01-22,CA,COVID-19,34 +2024-01-23,CA,COVID-19,27 +2024-01-24,CA,COVID-19,27 +2024-01-25,CA,COVID-19,43 +2024-01-26,CA,COVID-19,29 +2024-01-27,CA,COVID-19,21 +2024-01-28,CA,COVID-19,20 +2024-01-29,CA,COVID-19,28 +2024-01-30,CA,COVID-19,20 +2024-01-31,CA,COVID-19,23 +2024-02-01,CA,COVID-19,19 +2024-02-02,CA,COVID-19,23 +2024-02-03,CA,COVID-19,16 +2024-02-04,CA,COVID-19,15 +2024-02-05,CA,COVID-19,13 +2024-02-06,CA,COVID-19,13 +2024-02-07,CA,COVID-19,13 +2024-02-08,CA,COVID-19,13 +2024-02-09,CA,COVID-19,11 +2024-02-10,CA,COVID-19,9 +2024-02-11,CA,COVID-19,13 +2024-02-12,CA,COVID-19,9 +2024-02-13,CA,COVID-19,36 +2024-02-14,CA,COVID-19,25 +2024-02-15,CA,COVID-19,18 +2024-02-16,CA,COVID-19,26 +2024-02-17,CA,COVID-19,18 +2024-02-18,CA,COVID-19,23 +2024-02-19,CA,COVID-19,25 +2024-02-20,CA,COVID-19,34 +2024-02-21,CA,COVID-19,20 +2024-02-22,CA,COVID-19,21 +2024-02-23,CA,COVID-19,32 +2024-02-24,CA,COVID-19,24 +2024-02-25,CA,COVID-19,27 +2024-02-26,CA,COVID-19,36 +2024-02-27,CA,COVID-19,44 +2024-02-28,CA,COVID-19,39 +2024-02-29,CA,COVID-19,46 +2024-03-01,CA,COVID-19,23 +2024-03-02,CA,COVID-19,35 +2024-03-03,CA,COVID-19,29 +2024-03-04,CA,COVID-19,34 +2024-03-05,CA,COVID-19,27 +2024-03-06,CA,COVID-19,39 +2024-03-07,CA,COVID-19,31 +2024-03-08,CA,COVID-19,27 +2024-03-09,CA,COVID-19,13 diff --git a/pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv new file mode 100644 index 00000000..b36da6d5 --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/daily_hospital_admissions.csv @@ -0,0 +1,127 @@ +date,geo_value,daily_hosp_admits,pop +2023-11-05,CA,28,39512223 +2023-11-06,CA,31,39512223 +2023-11-07,CA,43,39512223 +2023-11-08,CA,30,39512223 +2023-11-09,CA,57,39512223 +2023-11-10,CA,47,39512223 +2023-11-11,CA,66,39512223 +2023-11-12,CA,94,39512223 +2023-11-13,CA,66,39512223 +2023-11-14,CA,81,39512223 +2023-11-15,CA,93,39512223 +2023-11-16,CA,122,39512223 +2023-11-17,CA,146,39512223 +2023-11-18,CA,119,39512223 +2023-11-19,CA,176,39512223 +2023-11-20,CA,209,39512223 +2023-11-21,CA,199,39512223 +2023-11-22,CA,225,39512223 +2023-11-23,CA,228,39512223 +2023-11-24,CA,260,39512223 +2023-11-25,CA,280,39512223 +2023-11-26,CA,281,39512223 +2023-11-27,CA,306,39512223 +2023-11-28,CA,328,39512223 +2023-11-29,CA,336,39512223 +2023-11-30,CA,361,39512223 +2023-12-01,CA,365,39512223 +2023-12-02,CA,403,39512223 +2023-12-03,CA,457,39512223 +2023-12-04,CA,433,39512223 +2023-12-05,CA,399,39512223 +2023-12-06,CA,434,39512223 +2023-12-07,CA,449,39512223 +2023-12-08,CA,442,39512223 +2023-12-09,CA,507,39512223 +2023-12-10,CA,524,39512223 +2023-12-11,CA,552,39512223 +2023-12-12,CA,528,39512223 +2023-12-13,CA,460,39512223 +2023-12-14,CA,554,39512223 +2023-12-15,CA,611,39512223 +2023-12-16,CA,506,39512223 +2023-12-17,CA,546,39512223 +2023-12-18,CA,598,39512223 +2023-12-19,CA,566,39512223 +2023-12-20,CA,554,39512223 +2023-12-21,CA,490,39512223 +2023-12-22,CA,566,39512223 +2023-12-23,CA,537,39512223 +2023-12-24,CA,491,39512223 +2023-12-25,CA,491,39512223 +2023-12-26,CA,445,39512223 +2023-12-27,CA,424,39512223 +2023-12-28,CA,470,39512223 +2023-12-29,CA,467,39512223 +2023-12-30,CA,373,39512223 +2023-12-31,CA,449,39512223 +2024-01-01,CA,347,39512223 +2024-01-02,CA,317,39512223 +2024-01-03,CA,361,39512223 +2024-01-04,CA,290,39512223 +2024-01-05,CA,269,39512223 +2024-01-06,CA,238,39512223 +2024-01-07,CA,224,39512223 +2024-01-08,CA,201,39512223 +2024-01-09,CA,185,39512223 +2024-01-10,CA,199,39512223 +2024-01-11,CA,179,39512223 +2024-01-12,CA,168,39512223 +2024-01-13,CA,143,39512223 +2024-01-14,CA,142,39512223 +2024-01-15,CA,128,39512223 +2024-01-16,CA,109,39512223 +2024-01-17,CA,108,39512223 +2024-01-18,CA,79,39512223 +2024-01-19,CA,81,39512223 +2024-01-20,CA,69,39512223 +2024-01-21,CA,66,39512223 +2024-01-22,CA,42,39512223 +2024-01-23,CA,57,39512223 +2024-01-24,CA,50,39512223 +2024-01-25,CA,50,39512223 +2024-01-26,CA,45,39512223 +2024-01-27,CA,33,39512223 +2024-01-28,CA,25,39512223 +2024-01-29,CA,25,39512223 +2024-01-30,CA,19,39512223 +2024-01-31,CA,23,39512223 +2024-02-01,CA,18,39512223 +2024-02-02,CA,17,39512223 +2024-02-03,CA,11,39512223 +2024-02-04,CA,13,39512223 +2024-02-05,CA,10,39512223 +2024-02-06,CA,17,39512223 +2024-02-07,CA,9,39512223 +2024-02-08,CA,13,39512223 +2024-02-09,CA,14,39512223 +2024-02-10,CA,14,39512223 +2024-02-11,CA,14,39512223 +2024-02-12,CA,11,39512223 +2024-02-13,CA,10,39512223 +2024-02-14,CA,11,39512223 +2024-02-15,CA,20,39512223 +2024-02-16,CA,17,39512223 +2024-02-17,CA,17,39512223 +2024-02-18,CA,12,39512223 +2024-02-19,CA,12,39512223 +2024-02-20,CA,11,39512223 +2024-02-21,CA,14,39512223 +2024-02-22,CA,12,39512223 +2024-02-23,CA,11,39512223 +2024-02-24,CA,18,39512223 +2024-02-25,CA,15,39512223 +2024-02-26,CA,15,39512223 +2024-02-27,CA,18,39512223 +2024-02-28,CA,23,39512223 +2024-02-29,CA,13,39512223 +2024-03-01,CA,16,39512223 +2024-03-02,CA,23,39512223 +2024-03-03,CA,13,39512223 +2024-03-04,CA,24,39512223 +2024-03-05,CA,23,39512223 +2024-03-06,CA,26,39512223 +2024-03-07,CA,25,39512223 +2024-03-08,CA,19,39512223 +2024-03-09,CA,21,39512223 diff --git a/pyrenew/datasets/synthetic_CA_126/daily_infections.csv b/pyrenew/datasets/synthetic_CA_126/daily_infections.csv new file mode 100644 index 00000000..461579ba --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/daily_infections.csv @@ -0,0 +1,127 @@ +date,true_infections,true_rt +2023-11-05,19756.11054520949,1.2 +2023-11-06,22143.721349321364,1.1936507936507936 +2023-11-07,24730.701954748823,1.1873015873015873 +2023-11-08,27520.279922418667,1.180952380952381 +2023-11-09,30513.614670228315,1.1746031746031746 +2023-11-10,33709.526838297024,1.1682539682539683 +2023-11-11,37104.241812393724,1.1619047619047618 +2023-11-12,40691.15530610575,1.1555555555555554 +2023-11-13,44460.6292337521,1.1492063492063491 +2023-11-14,48399.903877935554,1.1428571428571428 +2023-11-15,52492.90723467809,1.1365079365079365 +2023-11-16,56720.213937022985,1.1301587301587301 +2023-11-17,61059.047787575444,1.1238095238095238 +2023-11-18,65483.352178571426,1.1174603174603175 +2023-11-19,69963.9330002031,1.1111111111111112 +2023-11-20,74468.67701680085,1.1047619047619048 +2023-11-21,78962.84682237565,1.0984126984126985 +2023-11-22,83409.4513679222,1.092063492063492 +2023-11-23,87769.68897968013,1.0857142857142856 +2023-11-24,92003.45734561312,1.0793650793650793 +2023-11-25,96069.92272472705,1.073015873015873 +2023-11-26,99928.13841545957,1.0666666666666667 +2023-11-27,103537.7004820302,1.0603174603174603 +2023-11-28,106859.4269560898,1.053968253968254 +2023-11-29,109856.04528657829,1.0476190476190477 +2023-11-30,112492.87177639856,1.0412698412698413 +2023-12-01,114738.46618109582,1.0349206349206348 +2023-12-02,116565.24459766244,1.0285714285714285 +2023-12-03,117950.03426758068,1.0222222222222221 +2023-12-04,118874.55496352536,1.0158730158730158 +2023-12-05,119325.81320807226,1.0095238095238095 +2023-12-06,119296.39764740864,1.0031746031746032 +2023-12-07,118784.66641411892,0.9968253968253968 +2023-12-08,117794.82018176407,0.9904761904761905 +2023-12-09,116336.85774463006,0.9841269841269842 +2023-12-10,114426.41424020738,0.9777777777777779 +2023-12-11,112084.48545277533,0.9714285714285714 +2023-12-12,109337.04487359013,0.9650793650793651 +2023-12-13,106214.56322819905,0.9587301587301588 +2023-12-14,102751.4429030742,0.9523809523809524 +2023-12-15,98985.38201301648,0.946031746031746 +2023-12-16,94956.68466518822,0.9396825396825397 +2023-12-17,90707.5352330704,0.9333333333333333 +2023-12-18,86281.25511496875,0.926984126984127 +2023-12-19,81721.56050233734,0.9206349206349207 +2023-12-20,77071.83913347754,0.9142857142857144 +2023-12-21,72374.46289250287,0.907936507936508 +2023-12-22,67670.15148828148,0.9015873015873016 +2023-12-23,62997.40038884645,0.8952380952380953 +2023-12-24,58391.98378409807,0.888888888888889 +2023-12-25,53886.5407047299,0.8825396825396825 +2023-12-26,49510.249645148186,0.8761904761904762 +2023-12-27,45288.594230426665,0.8698412698412699 +2023-12-28,41243.219735682826,0.8634920634920635 +2023-12-29,37391.87770582045,0.8571428571428572 +2023-12-30,33748.45361723469,0.8507936507936509 +2023-12-31,30323.07053841518,0.8444444444444446 +2024-01-01,27122.260133653966,0.8380952380952382 +2024-01-02,24149.191145027726,0.8317460317460318 +2024-01-03,21403.944695610335,0.8253968253968255 +2024-01-04,18883.82537674655,0.8190476190476191 +2024-01-05,16583.69709315589,0.8126984126984127 +2024-01-06,14496.333006565756,0.8063492063492064 +2024-01-07,12612.769595127733,0.8 +2024-01-08,11108.34092566445,0.8071428571428572 +2024-01-09,9832.930973744562,0.8142857142857143 +2024-01-10,8747.948967868044,0.8214285714285715 +2024-01-11,7821.924519498369,0.8285714285714286 +2024-01-12,7029.102928870267,0.8357142857142857 +2024-01-13,6348.334538289707,0.8428571428571429 +2024-01-14,5762.19348424236,0.8500000000000001 +2024-01-15,5256.276056958226,0.8571428571428572 +2024-01-16,4818.547575787828,0.8642857142857143 +2024-01-17,4439.060337562475,0.8714285714285714 +2024-01-18,4109.530594712926,0.8785714285714287 +2024-01-19,3823.0425101568535,0.8857142857142858 +2024-01-20,3573.8090602874067,0.8928571428571429 +2024-01-21,3356.978392072291,0.9 +2024-01-22,3168.476589950405,0.9071428571428573 +2024-01-23,3004.8797151710523,0.9142857142857144 +2024-01-24,2863.309519082173,0.9214285714285715 +2024-01-25,2741.348211451286,0.9285714285714286 +2024-01-26,2636.96884685248,0.9357142857142857 +2024-01-27,2548.4784620555247,0.942857142857143 +2024-01-28,2474.471704111163,0.9500000000000001 +2024-01-29,2413.7931481554974,0.9571428571428572 +2024-01-30,2365.5068722504175,0.9642857142857144 +2024-01-31,2328.872153958227,0.9714285714285715 +2024-02-01,2303.3243955688663,0.9785714285714286 +2024-02-02,2288.460584558283,0.9857142857142858 +2024-02-03,2284.028762651382,0.9928571428571429 +2024-02-04,2289.921119069131,1.0 +2024-02-05,2306.170447669907,1.0071428571428571 +2024-02-06,2332.9498193450536,1.0142857142857142 +2024-02-07,2370.5754250480477,1.0214285714285716 +2024-02-08,2419.5126456106736,1.0285714285714287 +2024-02-09,2480.3855061775553,1.0357142857142858 +2024-02-10,2553.9897797971385,1.042857142857143 +2024-02-11,2641.3101207486616,1.05 +2024-02-12,2743.5417382594255,1.0571428571428572 +2024-02-13,2862.1172706974176,1.0642857142857143 +2024-02-14,2998.739695320636,1.0714285714285716 +2024-02-15,3155.4223166337065,1.0785714285714287 +2024-02-16,3334.537126326037,1.0857142857142859 +2024-02-17,3538.873130663702,1.092857142857143 +2024-02-18,3771.7066107178266,1.1 +2024-02-19,3967.4339363506274,1.0880952380952382 +2024-02-20,4143.30298203913,1.0761904761904764 +2024-02-21,4295.597223486165,1.0642857142857143 +2024-02-22,4420.940146912601,1.0523809523809524 +2024-02-23,4516.41535496139,1.0404761904761906 +2024-02-24,4579.67554658712,1.0285714285714287 +2024-02-25,4609.03482820428,1.0166666666666666 +2024-02-26,4603.539398956749,1.0047619047619047 +2024-02-27,4563.052362067281,0.9928571428571429 +2024-02-28,4488.221220621856,0.980952380952381 +2024-02-29,4380.476066032479,0.969047619047619 +2024-03-01,4241.985599243413,0.9571428571428572 +2024-03-02,4075.5851128393456,0.9452380952380952 +2024-03-03,3884.679728108972,0.9333333333333333 +2024-03-04,3673.127496098455,0.9214285714285715 +2024-03-05,3445.108006398925,0.9095238095238095 +2024-03-06,3204.982810520835,0.8976190476190476 +2024-03-07,2957.1543245741104,0.8857142857142857 +2024-03-08,2705.9296732363823,0.8738095238095238 +2024-03-09,2455.3954664262037,0.861904761904762 diff --git a/pyrenew/datasets/synthetic_CA_120/true_parameters.json b/pyrenew/datasets/synthetic_CA_126/true_parameters.json similarity index 86% rename from pyrenew/datasets/synthetic_CA_120/true_parameters.json rename to pyrenew/datasets/synthetic_CA_126/true_parameters.json index fecf1091..a95fe24c 100644 --- a/pyrenew/datasets/synthetic_CA_120/true_parameters.json +++ b/pyrenew/datasets/synthetic_CA_126/true_parameters.json @@ -1,8 +1,8 @@ { "description": "True parameters used to generate synthetic 120-day CA data. All values are known ground truth for posterior recovery checks.", "population": 39512223, - "start_date": "2023-11-06", - "n_days": 120, + "start_date": "2023-11-05", + "n_days": 126, "n_init": 55, "rng_seed": 20240101, "generation_interval_pmf": [ @@ -17,17 +17,17 @@ "i0_per_capita": 0.0005, "rt_trajectory": { "phase_1": { - "days": 60, + "days": 63, "start": 1.2, "end": 0.8 }, "phase_2": { - "days": 40, + "days": 42, "start": 0.8, "end": 1.1 }, "phase_3": { - "days": 20, + "days": 21, "start": 1.1, "end": 0.85 } @@ -35,7 +35,8 @@ "hospitalizations": { "ihr": 0.005, "delay_pmf_source": "infection_admission_interval.tsv", - "negbinom_concentration": 350.0, + "negbinom_concentration_daily": 350.0, + "negbinom_concentration_weekly": 100.0, "temporal_resolutions": [ "daily", "weekly_epiweek" diff --git a/pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv b/pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv new file mode 100644 index 00000000..efc9c9d2 --- /dev/null +++ b/pyrenew/datasets/synthetic_CA_126/weekly_hospital_admissions.csv @@ -0,0 +1,19 @@ +week_end,weekly_hosp_admits,location,pop +2023-11-11,341,CA,39512223 +2023-11-18,624,CA,39512223 +2023-11-25,1223,CA,39512223 +2023-12-02,2450,CA,39512223 +2023-12-09,3227,CA,39512223 +2023-12-16,4341,CA,39512223 +2023-12-23,4710,CA,39512223 +2023-12-30,3924,CA,39512223 +2024-01-06,2067,CA,39512223 +2024-01-13,1627,CA,39512223 +2024-01-20,625,CA,39512223 +2024-01-27,336,CA,39512223 +2024-02-03,151,CA,39512223 +2024-02-10,93,CA,39512223 +2024-02-17,70,CA,39512223 +2024-02-24,79,CA,39512223 +2024-03-02,114,CA,39512223 +2024-03-09,147,CA,39512223 diff --git a/pyrenew/datasets/synthetic_data.py b/pyrenew/datasets/synthetic_data.py index 13bcfd37..9b944ceb 100644 --- a/pyrenew/datasets/synthetic_data.py +++ b/pyrenew/datasets/synthetic_data.py @@ -1,7 +1,7 @@ """ -Loaders for synthetic 120-day California test data. +Loaders for synthetic 126-day California test data. -The synthetic dataset is generated by ``datagen_he_CA_120.py`` and contains +The synthetic dataset is generated by ``datagen_he_CA_126.py`` and contains daily infections, hospital admissions, ED visits, and ground-truth parameters for a single-jurisdiction COVID-19 scenario. """ @@ -16,14 +16,14 @@ def _synthetic_path() -> files: """ - Return the path to the synthetic_CA_120 data directory. + Return the path to the synthetic_CA_126 data directory. Returns ------- importlib.resources.abc.Traversable - Path to the synthetic_CA_120 package directory. + Path to the synthetic_CA_126 package directory. """ - return files("pyrenew.datasets") / "synthetic_CA_120" + return files("pyrenew.datasets") / "synthetic_CA_126" def load_synthetic_true_parameters() -> dict: diff --git a/pyrenew/latent/__init__.py b/pyrenew/latent/__init__.py index 45f6f06b..aecd5db1 100644 --- a/pyrenew/latent/__init__.py +++ b/pyrenew/latent/__init__.py @@ -32,6 +32,7 @@ AR1, DifferencedAR1, RandomWalk, + StepwiseTemporalProcess, TemporalProcess, ) @@ -61,4 +62,5 @@ "AR1", "DifferencedAR1", "RandomWalk", + "StepwiseTemporalProcess", ] diff --git a/pyrenew/latent/base.py b/pyrenew/latent/base.py index c84d393a..1b5a1eda 100644 --- a/pyrenew/latent/base.py +++ b/pyrenew/latent/base.py @@ -12,6 +12,7 @@ from jax.typing import ArrayLike from numpyro.util import not_jax_tracer +from pyrenew.latent.temporal_processes import TemporalProcess from pyrenew.metaclass import RandomVariable @@ -338,6 +339,31 @@ def get_required_lookback(self) -> int: """ return len(self.gen_int_rv()) + def requires_calendar_anchor(self) -> bool: + """ + Report whether this latent process needs a calendar anchor at sample time. + + The default implementation inspects this instance's attributes + for :class:`pyrenew.latent.temporal_processes.TemporalProcess` + instances and returns ``True`` if any of them has + ``alignment="calendar_week"``. Subclasses may override to add + additional calendar-aligned components. + + Returns + ------- + bool + ``True`` if the caller of :meth:`sample` must supply a + ``first_day_dow`` (derived from ``obs_start_date`` at the + model entry point); ``False`` otherwise. + """ + for value in vars(self).values(): + if ( + isinstance(value, TemporalProcess) + and getattr(value, "alignment", None) == "calendar_week" + ): + return True + return False + @abstractmethod def sample( self, diff --git a/pyrenew/latent/population_infections.py b/pyrenew/latent/population_infections.py index 1d1ce1cf..e6feb7c3 100644 --- a/pyrenew/latent/population_infections.py +++ b/pyrenew/latent/population_infections.py @@ -9,6 +9,7 @@ from jax.typing import ArrayLike from numpyro.util import not_jax_tracer +from pyrenew.arrayutils import require_shape from pyrenew.deterministic import DeterministicVariable from pyrenew.distutil import validate_discrete_dist_vector from pyrenew.latent.base import ( @@ -147,13 +148,18 @@ def sample( self, n_days_post_init: int, subpop_fractions: ArrayLike | None = None, + first_day_dow: int | None = None, **kwargs: object, ) -> LatentSample: """ Sample population infections using a single renewal process. - Generates a single $\\mathcal{R}(t)$ trajectory, computes initial infections via - exponential backprojection, and runs one renewal equation. + Generates a single daily $\\mathcal{R}(t)$ trajectory, computes initial + infections via exponential backprojection, and runs one deterministic + daily renewal equation. The temporal process may sample parameters at + any supported cadence (for example daily or weekly/stepwise), but it + must return a daily-length trajectory before the renewal equation is + evaluated. Parameters ---------- @@ -162,6 +168,9 @@ def sample( subpop_fractions Population fractions. Defaults to ``[1.0]`` (single population). Must be ``[1.0]`` if provided. + first_day_dow + Forwarded to ``single_rt_process``. See + [pyrenew.latent.TemporalProcess][]. **kwargs Additional arguments (unused, for compatibility) @@ -196,7 +205,9 @@ def sample( n_timepoints=n_total_days, initial_value=initial_log_rt, name_prefix="log_rt_single", + first_day_dow=first_day_dow, ) + require_shape(log_rt_single, (n_total_days, 1), "single_rt_process") rt_single = jnp.exp(log_rt_single) diff --git a/pyrenew/latent/subpopulation_infections.py b/pyrenew/latent/subpopulation_infections.py index 613f75f8..453365e3 100644 --- a/pyrenew/latent/subpopulation_infections.py +++ b/pyrenew/latent/subpopulation_infections.py @@ -11,6 +11,7 @@ import numpyro from jax.typing import ArrayLike +from pyrenew.arrayutils import require_shape from pyrenew.deterministic import DeterministicVariable from pyrenew.distutil import validate_discrete_dist_vector from pyrenew.latent.base import ( @@ -162,13 +163,18 @@ def sample( self, n_days_post_init: int, subpop_fractions: ArrayLike | None = None, + first_day_dow: int | None = None, **kwargs: object, ) -> LatentSample: """ Sample hierarchical infections for all subpopulations. - Generates baseline $\\mathcal{R}(t)$, subpopulation deviations with sum-to-zero - constraint, initial infections, and runs n_subpops independent renewal processes. + Generates daily baseline $\\mathcal{R}(t)$, daily subpopulation + deviations with a sum-to-zero constraint, initial infections, and runs + n_subpops deterministic daily renewal processes. The temporal + processes may sample parameters at any supported cadence (for example + daily or weekly/stepwise), but they must return daily-length + trajectories before the renewal equations are evaluated. Parameters ---------- @@ -177,6 +183,10 @@ def sample( subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). Must sum to 1.0. + first_day_dow + Forwarded to ``baseline_rt_process`` and + ``subpop_rt_deviation_process``. See + [pyrenew.latent.TemporalProcess][]. **kwargs Additional arguments (unused, for compatibility) @@ -199,13 +209,21 @@ def sample( n_timepoints=n_total_days, initial_value=initial_log_rt, name_prefix="log_rt_baseline", + first_day_dow=first_day_dow, ) + require_shape(log_rt_baseline, (n_total_days, 1), "baseline_rt_process") deviations_raw = self.subpop_rt_deviation_process.sample( n_timepoints=n_total_days, n_processes=pop.n_subpops, initial_value=jnp.zeros(pop.n_subpops), name_prefix="subpop_deviations", + first_day_dow=first_day_dow, + ) + require_shape( + deviations_raw, + (n_total_days, pop.n_subpops), + "subpop_rt_deviation_process", ) # Sum-to-zero constraint ensures identifiability diff --git a/pyrenew/latent/temporal_processes.py b/pyrenew/latent/temporal_processes.py index 4020564f..9b44fab2 100644 --- a/pyrenew/latent/temporal_processes.py +++ b/pyrenew/latent/temporal_processes.py @@ -39,6 +39,9 @@ while stabilizing the growth rate. Wraps [pyrenew.process.DifferencedProcess][]. - ``RandomWalk``: No mean reversion. Rt can drift without bound. Wraps [pyrenew.process.RandomWalk][]. +- ``StepwiseTemporalProcess``: Wrapper that parameterizes any inner + ``TemporalProcess`` at a coarser cadence and broadcasts to the per-timepoint + scale by repetition. All implementations satisfy the ``TemporalProcess`` protocol and can be used interchangeably in hierarchical infection models. @@ -46,7 +49,7 @@ from __future__ import annotations -from typing import Protocol, runtime_checkable +from typing import Literal, Protocol, runtime_checkable import jax.numpy as jnp import numpyro @@ -56,6 +59,7 @@ from pyrenew.process import ARProcess, DifferencedProcess from pyrenew.process.randomwalk import RandomWalk as ProcessRandomWalk from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import WeekCycle, validate_dow, weekly_to_daily @runtime_checkable @@ -66,14 +70,27 @@ class TemporalProcess(Protocol): Used for jurisdiction-level Rt dynamics, subpopulation deviations, or allocation trajectories. All processes return 2D arrays of shape (n_timepoints, n_processes) for consistent handling. + + Attributes + ---------- + step_size : int + Number of consecutive timepoints that share the same sampled value. + Defaults to ``1`` for the standard processes (one independent sample + per timepoint). Wrapper processes like ``StepwiseTemporalProcess`` + expose a larger value so that model builders can enforce coherence + between R(t) parametrization cadence and observation aggregation. """ + step_size: int + def sample( self, n_timepoints: int, initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "temporal", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample temporal trajectory or trajectories. @@ -90,6 +107,10 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample site names to avoid collisions + first_day_dow + Day of week for element 0 of the shared model time axis + (0=Monday, ..., 6=Sunday). Standard temporal processes ignore + this value; calendar-aligned wrappers may use it. Returns ------- @@ -120,6 +141,8 @@ class AR1(TemporalProcess): more volatile trajectories; smaller values produce smoother ones. """ + step_size: int = 1 + def __init__(self, autoreg: float, innovation_sd: float = 1.0) -> None: """ Initialize AR(1) process. @@ -153,6 +176,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "ar1", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample AR(1) trajectory or trajectories. @@ -167,6 +192,8 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites + first_day_dow + Unused. See [pyrenew.latent.TemporalProcess][]. Returns ------- @@ -220,6 +247,8 @@ class DifferencedAR1(TemporalProcess): more erratic growth rates; smaller values produce smoother trends. """ + step_size: int = 1 + def __init__(self, autoreg: float, innovation_sd: float = 1.0) -> None: """ Initialize differenced AR(1) process. @@ -258,6 +287,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "diff_ar1", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample differenced AR(1) trajectory or trajectories. @@ -272,6 +303,8 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites + first_day_dow + Unused. See [pyrenew.latent.TemporalProcess][]. Returns ------- @@ -331,6 +364,8 @@ class RandomWalk(TemporalProcess): (``{name_prefix}_step``) via ``numpyro.handlers.reparam``. """ + step_size: int = 1 + def __init__(self, innovation_sd: float = 1.0) -> None: """ Initialize random walk process. @@ -359,6 +394,8 @@ def sample( initial_value: float | ArrayLike | None = None, n_processes: int = 1, name_prefix: str = "rw", + *, + first_day_dow: int | None = None, ) -> ArrayLike: """ Sample random walk trajectory or trajectories. @@ -373,6 +410,8 @@ def sample( Number of parallel processes. name_prefix Prefix for numpyro sample sites + first_day_dow + Unused. See [pyrenew.latent.TemporalProcess][]. Returns ------- @@ -399,3 +438,184 @@ def sample( init_vals=initial_value[jnp.newaxis, :], n=n_timepoints, ) + + +class StepwiseTemporalProcess(TemporalProcess): + """ + Parameterize an inner temporal process at a coarser cadence and + broadcast to the per-timepoint scale by repetition. + + Each ``step_size`` consecutive output timepoints share a single sampled + value from the inner process. Use when a parameter should be estimated at + a coarser cadence while downstream model components still need one value + per evaluation timepoint. For example, a weekly-parameterized R(t) process + can return daily R(t) values for a daily deterministic renewal equation. + + Parameters + ---------- + inner + Inner ``TemporalProcess`` that generates the coarse-scale + trajectory. Must satisfy the ``TemporalProcess`` protocol. + step_size + Number of per-timepoint units that share each inner sample. + Must be a positive integer. + alignment + How repeated blocks align to the output time axis. ``"model_index"`` + starts blocks at output index 0. ``"calendar_week"`` aligns blocks + to calendar weeks using ``week`` and the ``first_day_dow`` + supplied to ``sample()``. + week + Calendar-week anchor used when ``alignment="calendar_week"`` + (e.g., :data:`pyrenew.time.MMWR_WEEK`, + :data:`pyrenew.time.ISO_WEEK`). Required for calendar-week + alignment; must be ``None`` otherwise. + + Raises + ------ + ValueError + If ``step_size`` is not a positive integer, or if alignment + arguments are inconsistent. + """ + + _SUPPORTED_ALIGNMENTS = {"model_index", "calendar_week"} + + def __init__( + self, + inner: TemporalProcess, + step_size: int, + alignment: Literal["model_index", "calendar_week"] = "model_index", + week: WeekCycle | None = None, + ) -> None: + """ + Initialize stepwise temporal process. + + Parameters + ---------- + inner + Inner ``TemporalProcess`` that generates the coarse trajectory. + step_size + Number of per-timepoint units that share each inner sample. + alignment + How repeated blocks align to the output time axis. ``"model_index"`` + starts blocks at output index 0. ``"calendar_week"`` aligns + weekly blocks to ``week`` using ``first_day_dow`` at + sample time. + week + Calendar-week anchor used when + ``alignment="calendar_week"``. + + Raises + ------ + ValueError + If ``step_size`` is not a positive integer, or if alignment + arguments are inconsistent. + """ + if not isinstance(step_size, int) or step_size < 1: + raise ValueError(f"step_size must be a positive integer, got {step_size!r}") + if alignment not in self._SUPPORTED_ALIGNMENTS: + raise ValueError( + f"alignment must be one of {self._SUPPORTED_ALIGNMENTS}, " + f"got {alignment!r}" + ) + if alignment == "calendar_week": + if step_size != 7: + raise ValueError( + "calendar_week alignment requires step_size=7, " + f"got step_size={step_size}" + ) + if week is None: + raise ValueError("week is required when alignment='calendar_week'") + elif week is not None: + raise ValueError("week is only used when alignment='calendar_week'") + self.inner = inner + self.step_size = step_size + self.alignment = alignment + self.week = week + + def __repr__(self) -> str: + """Return string representation.""" + return ( + f"StepwiseTemporalProcess(inner={self.inner!r}, " + f"step_size={self.step_size}, alignment={self.alignment!r}, " + f"week={self.week!r})" + ) + + def _resolve_n_coarse(self, n_timepoints: int, first_day_dow: int | None) -> int: + """ + Return the number of inner-process samples needed. + + Returns + ------- + int + Number of coarse samples required to cover ``n_timepoints`` under + the configured alignment. + """ + if self.alignment == "model_index": + return (n_timepoints + self.step_size - 1) // self.step_size + if first_day_dow is None: + raise ValueError( + "first_day_dow is required at sample time when " + "alignment='calendar_week'" + ) + validate_dow(first_day_dow, "first_day_dow") + trim = (first_day_dow - self.week.start_dow) % 7 + return (n_timepoints + trim + 6) // 7 + + def sample( + self, + n_timepoints: int, + initial_value: float | ArrayLike | None = None, + n_processes: int = 1, + name_prefix: str = "stepwise", + *, + first_day_dow: int | None = None, + ) -> ArrayLike: + """ + Sample coarse trajectory from inner process and broadcast. + + Computes the number of coarse time steps needed for the requested + alignment, samples the inner process at that cadence, then broadcasts + each coarse value to the per-timepoint axis and trims to + ``n_timepoints``. The returned value always has one row per evaluation + timepoint, regardless of the inner parameter cadence. The coarse + trajectory is recorded as a NumPyro deterministic site named + ``"{name_prefix}_coarse"``. + + Parameters + ---------- + n_timepoints + Number of per-timepoint outputs to produce. + initial_value + Initial value(s) for the inner process. Defaults to 0.0. + n_processes + Number of parallel processes. + name_prefix + Prefix for numpyro sample sites; forwarded to the inner process. + first_day_dow + Day of week for element 0 of the shared model time axis + (0=Monday, ..., 6=Sunday). Required when + ``alignment="calendar_week"``. + + Returns + ------- + ArrayLike + Trajectories of shape ``(n_timepoints, n_processes)``, constant + within each block of ``step_size`` consecutive rows. + """ + n_steps = self._resolve_n_coarse(n_timepoints, first_day_dow) + # first_day_dow intentionally not forwarded: inner operates on the + # coarse axis; the outer's axis-origin day-of-week does not apply. + coarse = self.inner.sample( + n_timepoints=n_steps, + initial_value=initial_value, + n_processes=n_processes, + name_prefix=name_prefix, + ) + numpyro.deterministic(f"{name_prefix}_coarse", coarse) + if self.alignment == "model_index": + return jnp.repeat(coarse, repeats=self.step_size, axis=0)[:n_timepoints] + return weekly_to_daily( + coarse, + week_start_dow=self.week.start_dow, + output_data_first_dow=first_day_dow, + )[:n_timepoints] diff --git a/pyrenew/model/multisignal_model.py b/pyrenew/model/multisignal_model.py index 8a124348..7e9a099e 100644 --- a/pyrenew/model/multisignal_model.py +++ b/pyrenew/model/multisignal_model.py @@ -6,7 +6,10 @@ from __future__ import annotations +import datetime as dt + import jax.numpy as jnp +import numpy as np import numpyro import numpyro.handlers from jax.typing import ArrayLike @@ -14,6 +17,7 @@ from pyrenew.latent.base import BaseLatentInfectionProcess from pyrenew.metaclass import Model from pyrenew.observation.base import BaseObservationProcess +from pyrenew.time import convert_date class MultiSignalModel(Model): @@ -130,29 +134,77 @@ def pad_observations( padding = jnp.full(pad_shape, jnp.nan) return jnp.concatenate([padding, obs], axis=axis) - def compute_first_day_dow(self, obs_start_dow: int) -> int: + def _resolve_first_day_dow( + self, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None, + ) -> int | None: """ - Compute the day of the week for the start of the shared time axis. + Derive the axis-origin day-of-week from ``obs_start_date``. - The shared time axis begins ``n_init`` days before the first - observation. This method converts the known day of the week of - the first observation into the day of the week of the shared - time axis start (element 0), accounting for the initialization - period offset. + The shared daily time axis begins ``n_init`` days before the + first observation. This method converts a user-supplied + first-observation date into the day-of-week of element 0 of + that axis, accounting for the initialization-period offset. Parameters ---------- - obs_start_dow : int - Day of the week of the first observation day - (0=Monday, 6=Sunday, ISO convention). + obs_start_date + Date of the first observation day, or ``None``. Returns ------- - int - Day of the week for element 0 of the shared time axis. + int or None + Day-of-week index in ``{0, ..., 6}`` (0=Monday, ISO + convention) of element 0 of the shared axis. ``None`` when + ``obs_start_date`` is ``None``. """ + if obs_start_date is None: + return None n_init = self.latent.n_initialization_points - return (obs_start_dow - n_init) % 7 + return (convert_date(obs_start_date).weekday() - n_init) % 7 + + def _check_obs_start_date( + self, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None, + ) -> None: + """ + Check that ``obs_start_date`` is supplied whenever any component needs it. + + Observations with ``aggregation="weekly"`` or a + ``day_of_week_rv`` require a calendar anchor, as do latent + processes with a calendar-week-aligned temporal process. + Rather than surface a downstream error, raise at the model + entry naming the offending component. + + Parameters + ---------- + obs_start_date + Top-level calendar anchor passed by the caller. + + Raises + ------ + ValueError + If ``obs_start_date`` is ``None`` and any observation or + the latent process requires a calendar anchor. + """ + if obs_start_date is not None: + return + for name, obs in self.observations.items(): + if getattr(obs, "aggregation", "daily") == "weekly": + raise ValueError( + f"obs_start_date is required when any observation uses " + f"aggregation='weekly'; observation '{name}' does." + ) + if getattr(obs, "day_of_week_rv", None) is not None: + raise ValueError( + f"obs_start_date is required when any observation uses " + f"a day-of-week effect; observation '{name}' does." + ) + if self.latent.requires_calendar_anchor(): + raise ValueError( + "obs_start_date is required when the latent process uses a " + "temporal process with alignment='calendar_week'." + ) def shift_times(self, times: jnp.ndarray) -> jnp.ndarray: """ @@ -179,7 +231,9 @@ def shift_times(self, times: jnp.ndarray) -> jnp.ndarray: def validate_data( self, n_days_post_init: int, + *, subpop_fractions: ArrayLike | None = None, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None = None, **observation_data: dict[str, object], ) -> None: """ @@ -197,9 +251,16 @@ def validate_data( Parameters ---------- n_days_post_init - Number of days to simulate after initialization period + Number of days to simulate after initialization period. subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). + obs_start_date + Date of the first observation day. Required when any + observation uses ``aggregation="weekly"`` or a day-of-week + effect, or when a latent temporal process uses + calendar-week alignment. Converted once to the axis-origin + ``first_day_dow`` and forwarded to the latent process and + every observation. **observation_data Data for each observation process, keyed by observation name. Each value should be a dict of kwargs for that observation's sample(). @@ -207,16 +268,19 @@ def validate_data( Raises ------ ValueError - If times indices are out of bounds or negative - If dense obs length doesn't match n_total - If data shapes are inconsistent + If times indices are out of bounds or negative, if dense obs + length doesn't match n_total, if data shapes are inconsistent, + or if ``obs_start_date`` is missing when an observation requires it. """ + self._check_obs_start_date(obs_start_date) + pop = self.latent._parse_and_validate_fractions( subpop_fractions=subpop_fractions, ) n_init = self.latent.n_initialization_points n_total = n_init + n_days_post_init + first_day_dow = self._resolve_first_day_dow(obs_start_date) for name, obs_data in observation_data.items(): if name not in self.observations: @@ -225,9 +289,11 @@ def validate_data( f"Available: {list(self.observations.keys())}" ) - self.observations[name].validate_data( + obs = self.observations[name] + obs.validate_data( n_total=n_total, n_subpops=pop.n_subpops, + first_day_dow=first_day_dow, **obs_data, ) @@ -237,6 +303,7 @@ def sample( population_size: float, *, subpop_fractions: ArrayLike | None = None, + obs_start_date: dt.date | dt.datetime | np.datetime64 | None = None, **observation_data: dict[str, object], ) -> None: """ @@ -253,6 +320,14 @@ def sample( (from latent process) to infection counts (for observation processes). subpop_fractions Population fractions for all subpopulations. Shape: (n_subpops,). + obs_start_date + Date of the first observation day. Converted once to the + axis-origin ``first_day_dow`` (day-of-week of element 0 of + the padded axis, after subtracting ``n_init``) and forwarded + to the latent process and every observation. Required when + any observation uses ``aggregation="weekly"`` or a + day-of-week effect, or when a latent temporal process uses + calendar-week alignment. **observation_data Data for each observation process, keyed by observation name (the ``name`` attribute of each observation process). @@ -266,10 +341,14 @@ def sample( observation sites. Use ``numpyro.infer.Predictive`` for forward sampling. """ + self._check_obs_start_date(obs_start_date) + first_day_dow = self._resolve_first_day_dow(obs_start_date) + # Generate latent infections (proportions) latent_sample = self.latent.sample( n_days_post_init=n_days_post_init, subpop_fractions=subpop_fractions, + first_day_dow=first_day_dow, ) # Scale from proportions to counts @@ -303,6 +382,7 @@ def sample( # Sample from observation process obs_process.sample( infections=latent_infections, + first_day_dow=first_day_dow, **obs_data, ) diff --git a/pyrenew/model/pyrenew_builder.py b/pyrenew/model/pyrenew_builder.py index c8efa4f8..02fc1e3f 100644 --- a/pyrenew/model/pyrenew_builder.py +++ b/pyrenew/model/pyrenew_builder.py @@ -203,7 +203,6 @@ def build(self) -> MultiSignalModel: 1. Computes n_initialization_points from all components 2. Constructs the latent process with the computed value 3. Creates a MultiSignalModel with automatic infection routing - 4. Validates that observation/latent types are compatible Can be called multiple times to create multiple model instances. @@ -215,7 +214,7 @@ def build(self) -> MultiSignalModel: Raises ------ ValueError - If latent process not configured + If latent process not configured. """ if self.latent_class is None: raise ValueError("Must call configure_latent() before build()") diff --git a/pyrenew/observation/base.py b/pyrenew/observation/base.py index 3ddcfb1a..1077cbb4 100644 --- a/pyrenew/observation/base.py +++ b/pyrenew/observation/base.py @@ -14,8 +14,10 @@ import numpyro from jax.typing import ArrayLike +from pyrenew.arrayutils import require_shape from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.metaclass import RandomVariable +from pyrenew.time import WeekCycle class BaseObservationProcess(RandomVariable): @@ -191,13 +193,79 @@ def _validate_dow_effect( ValueError If shape is not (7,) or any values are negative. """ - if dow_effect.shape != (7,): - raise ValueError( - f"{param_name} must return shape (7,), got {dow_effect.shape}" - ) + require_shape(dow_effect, (7,), param_name) if jnp.any(dow_effect < 0): raise ValueError(f"{param_name} must have non-negative values") + def _validate_week( + self, + aggregation: str, + week: WeekCycle | None, + ) -> None: + """ + Validate the ``(aggregation, week)`` pair. + + ``aggregation="weekly"`` requires a :class:`WeekCycle`; + ``aggregation="daily"`` ignores ``week``. + + Parameters + ---------- + aggregation + Observation reporting cadence; one of ``"daily"`` or + ``"weekly"``. + week + Calendar-week anchor; required iff + ``aggregation == "weekly"``. + + Raises + ------ + ValueError + If ``aggregation`` is unrecognized, or if + ``aggregation == "weekly"`` and ``week`` is ``None``. + """ + if aggregation not in ("daily", "weekly"): + raise ValueError( + f"aggregation must be one of {{'daily', 'weekly'}}, got {aggregation!r}" + ) + if aggregation == "weekly" and week is None: + raise ValueError("week is required when aggregation == 'weekly'") + + def _compute_period_offset( + self, + first_day_dow: int | None, + week: WeekCycle | None, + ) -> int: + """ + Compute the number of leading daily timepoints to trim so + that the daily axis aligns to whole weekly periods. + + Parameters + ---------- + first_day_dow + Day-of-week index of element 0 of the daily axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``week`` is not ``None``. + week + Calendar-week anchor. ``None`` indicates daily + (non-aggregated) observations; in that case the + offset is ``0``. + + Returns + ------- + int + Trim offset in ``[0, 7)``. Returns ``0`` when ``week`` is ``None``. + + Raises + ------ + ValueError + If ``week`` is provided but ``first_day_dow`` is ``None``. + """ + if week is None: + return 0 + if first_day_dow is None: + raise ValueError("first_day_dow is required when week is not None") + return (week.end_dow + 1 - first_day_dow) % 7 + def _convolve_with_alignment( self, latent_incidence: ArrayLike, @@ -360,9 +428,11 @@ def _validate_index_array( self, indices: ArrayLike, upper_bound: int, param_name: str ) -> None: """ - Validate an index array has non-negative values within bounds. + Validate an index array is 1D with non-negative values within bounds. - Checks that all values are non-negative integers in ``[0, upper_bound)``. + Checks that the array is 1D and that all values are non-negative + integers in ``[0, upper_bound)``. An empty 1D array is a no-op + and passes the bounds check. Parameters ---------- @@ -376,9 +446,17 @@ def _validate_index_array( Raises ------ ValueError - If indices contains negative values or values >= upper_bound. + If indices is not 1D, contains negative values, or values + >= upper_bound. """ indices = jnp.asarray(indices) + if indices.ndim != 1: + raise ValueError( + f"Observation '{self.name}': {param_name} must be 1D, " + f"got shape {indices.shape}" + ) + if indices.size == 0: + return if jnp.any(indices < 0): raise ValueError( f"Observation '{self.name}': {param_name} cannot be negative" @@ -432,28 +510,38 @@ def _validate_subpop_indices( """ self._validate_index_array(subpop_indices, n_subpops, "subpop_indices") - def _validate_obs_times_shape(self, obs: ArrayLike, times: ArrayLike) -> None: + def _validate_shapes_match( + self, + first: ArrayLike, + second: ArrayLike, + first_name: str, + second_name: str, + ) -> None: """ - Validate that obs and times arrays have matching shapes. + Validate that two arrays have matching shapes. Parameters ---------- - obs - Observed data array. - times - Times index array. + first + First array. + second + Second array. + first_name + Name of the first parameter (for error messages). + second_name + Name of the second parameter (for error messages). Raises ------ ValueError - If obs and times have different shapes. + If the two arrays have different shapes. """ - obs = jnp.asarray(obs) - times = jnp.asarray(times) - if obs.shape != times.shape: + first = jnp.asarray(first) + second = jnp.asarray(second) + if first.shape != second.shape: raise ValueError( - f"Observation '{self.name}': obs shape {obs.shape} " - f"must match times shape {times.shape}" + f"Observation '{self.name}': {first_name} shape {first.shape} " + f"must match {second_name} shape {second.shape}" ) def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: @@ -461,8 +549,9 @@ def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: Validate that obs covers the full shared time axis. For dense observations on the shared time axis ``[0, n_total)``, - obs must have length equal to ``n_total``. Use NaN to mark - unobserved timepoints (initialization period or missing data). + obs must be 1D with length equal to ``n_total``. Use NaN to + mark unobserved timepoints (initialization period or missing + data). Parameters ---------- @@ -474,9 +563,13 @@ def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: Raises ------ ValueError - If obs length doesn't equal n_total. + If obs is not 1D or its length doesn't equal n_total. """ obs = jnp.asarray(obs) + if obs.ndim != 1: + raise ValueError( + f"Observation '{self.name}': obs must be 1D, got shape {obs.shape}" + ) if obs.shape[0] != n_total: raise ValueError( f"Observation '{self.name}': obs length {obs.shape[0]} " @@ -484,6 +577,74 @@ def _validate_obs_dense(self, obs: ArrayLike, n_total: int) -> None: f"Pad with NaN for initialization period." ) + def _validate_period_end_times( + self, + period_end_times: ArrayLike, + n_total: int, + offset: int, + aggregation_period: int, + ) -> None: + """ + Validate a period-end-time index array. + + Checks that all values are non-negative, within + ``[0, n_total)``, are at least ``offset + aggregation_period - 1`` + (so the earliest period end points at the last day of the first + complete period, not inside the leading partial period), and lie + on aggregation-period boundaries, i.e., ``(t - offset) % + aggregation_period == aggregation_period - 1`` for every entry + ``t``. When ``aggregation_period == 1`` the alignment condition + holds trivially and only the bounds check runs. + + Parameters + ---------- + period_end_times + Daily-axis indices of each observed period's final day. + n_total + Total number of time steps. + offset + Front-trim offset in daily units, as returned by + ``_compute_period_offset``. Must be in + ``[0, aggregation_period)``. + aggregation_period + Width of the reporting period in fundamental time units. + + Raises + ------ + ValueError + If ``period_end_times`` contains negative values, values + ``>= n_total``, values below the first complete period's + final day, or entries that fail the alignment check. + """ + self._validate_index_array(period_end_times, n_total, "period_end_times") + if aggregation_period == 1: + return + period_end_times = jnp.asarray(period_end_times) + + # Lower bound: the first complete period's final day. An entry below + # this would yield a negative period index under fancy-indexing + # (which JAX silently wraps to the last element of the aggregated + # array). + min_valid = offset + aggregation_period - 1 + if jnp.any(period_end_times < min_valid): + raise ValueError( + f"Observation '{self.name}': period_end_times must be >= " + f"{min_valid} (= offset + aggregation_period - 1); entries " + f"below this do not correspond to a complete aggregation period." + ) + + misaligned = (period_end_times - offset) % aggregation_period != ( + aggregation_period - 1 + ) + if jnp.any(misaligned): + raise ValueError( + f"Observation '{self.name}': period_end_times must lie on " + f"aggregation-period boundaries " + f"(offset={offset}, aggregation_period={aggregation_period}); " + f"each entry t must satisfy " + f"(t - offset) % {aggregation_period} == {aggregation_period - 1}." + ) + @abstractmethod def sample( self, diff --git a/pyrenew/observation/count_observations.py b/pyrenew/observation/count_observations.py index 563285c2..24ecd6f2 100644 --- a/pyrenew/observation/count_observations.py +++ b/pyrenew/observation/count_observations.py @@ -5,6 +5,8 @@ from __future__ import annotations +from typing import Literal + import jax import jax.numpy as jnp from jax.typing import ArrayLike @@ -14,7 +16,11 @@ from pyrenew.observation.base import BaseObservationProcess from pyrenew.observation.noise import CountNoise from pyrenew.observation.types import ObservationSample -from pyrenew.time import get_sequential_day_of_week_indices +from pyrenew.time import ( + WeekCycle, + daily_to_weekly, + get_sequential_day_of_week_indices, +) class CountObservation(BaseObservationProcess): @@ -22,9 +28,14 @@ class CountObservation(BaseObservationProcess): Abstract Base class for count observation processes. Subclasses map infections to counts through ascertainment x delay convolution - with composable noise model. + with composable noise model. Count observations always receive predictions + on the model's daily time axis and then, if requested, aggregate those + daily predictions to the observation reporting grid before evaluating the + likelihood. """ + _SUPPORTED_SCHEDULES = ("regular", "irregular") + def __init__( self, name: str, @@ -33,6 +44,9 @@ def __init__( noise: CountNoise, right_truncation_rv: RandomVariable | None = None, day_of_week_rv: RandomVariable | None = None, + aggregation: Literal["daily", "weekly"] = "daily", + reporting_schedule: Literal["regular", "irregular"] = "regular", + week: WeekCycle | None = None, ) -> None: """ Initialize count observation base. @@ -63,12 +77,63 @@ def __init__( overall predicted counts. When provided (along with ``first_day_dow`` at sample time), predicted counts are scaled by a periodic weekly pattern. + aggregation + Observation reporting cadence; one of ``"daily"`` or + ``"weekly"``. Controls only the scale on which the count + likelihood is evaluated; it does not control how often the + latent Rt temporal process samples new parameters. + reporting_schedule + Either ``"regular"`` (dense observation array, one entry + per period, NaN for unobserved periods) or ``"irregular"`` + (sparse observation array with user-supplied period-end + time indices). + week + Calendar-week anchor used for weekly aggregation (e.g., + :data:`pyrenew.time.MMWR_WEEK` for Sunday-Saturday + epiweeks, :data:`pyrenew.time.ISO_WEEK` for Monday-Sunday + weeks). Required when ``aggregation == "weekly"``; ignored + otherwise. Daily predictions are bucketed into weeks + according to ``week`` and summed before scoring. + + Raises + ------ + ValueError + If ``aggregation``, ``reporting_schedule``, or ``week`` are + invalid, or if a day-of-week effect is combined with + ``aggregation == "weekly"`` (within-period structure is + destroyed by aggregation). """ super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv) self.ascertainment_rate_rv = ascertainment_rate_rv self.noise = noise self.right_truncation_rv = right_truncation_rv self.day_of_week_rv = day_of_week_rv + self._validate_week(aggregation, week) + if reporting_schedule not in self._SUPPORTED_SCHEDULES: + raise ValueError( + f"reporting_schedule must be one of {self._SUPPORTED_SCHEDULES}, " + f"got {reporting_schedule!r}" + ) + if aggregation == "weekly" and day_of_week_rv is not None: + raise ValueError( + "day_of_week_rv cannot be combined with aggregation == 'weekly'; " + "aggregation destroys within-period structure." + ) + self.aggregation = aggregation + self.reporting_schedule = reporting_schedule + self.week = week + + @property + def aggregation_period(self) -> int: + """ + Width of the observation reporting period in days. + + Returns + ------- + int + ``1`` for daily aggregation, ``7`` for weekly. + """ + return 7 if self.aggregation == "weekly" else 1 def validate(self) -> None: """ @@ -244,13 +309,231 @@ def _apply_day_of_week( daily_effect = daily_effect[:, None] return predicted * daily_effect + def _aggregate( + self, + predicted_daily: ArrayLike, + first_day_dow: int | None, + ) -> ArrayLike: + """ + Aggregate daily predicted counts to the observation reporting grid. + + When ``aggregation == "daily"`` returns the input unchanged. + Otherwise sums daily values over non-overlapping fixed-width + periods anchored by ``week``, via + ``pyrenew.time.daily_to_weekly``. Works on both 1D + ``(n_total,)`` and 2D ``(n_total, n_subpops)`` inputs. + This aggregation is part of the observation likelihood path and is + independent of the parameter cadence used by the latent Rt process. + + Parameters + ---------- + predicted_daily + Predicted counts on the daily time axis. + first_day_dow + Day-of-week index of element 0 of the daily axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation == "weekly"``. + + Returns + ------- + ArrayLike + Aggregated counts on the period grid; same trailing + dimensions as ``predicted_daily``. Returns + ``predicted_daily`` unchanged when + ``aggregation == "daily"``. + + Raises + ------ + ValueError + If ``aggregation == "weekly"`` and ``first_day_dow`` is ``None``. + """ + if self.aggregation == "daily": + return predicted_daily + if first_day_dow is None: + raise ValueError("first_day_dow is required when aggregation == 'weekly'") + return daily_to_weekly( + predicted_daily, + input_data_first_dow=first_day_dow, + week_start_dow=self.week.start_dow, + ) + + def _compute_predicted( + self, + infections: ArrayLike, + first_day_dow: int | None, + right_truncation_offset: int | None, + ) -> ArrayLike: + """ + Build the predicted counts on the reporting-period grid. + + Runs ascertainment and delay convolution, then optionally + applies the day-of-week multiplicative effect and + right-truncation adjustment, and aggregates to the reporting + grid. Emits ``predicted`` (and ``predicted_daily`` when + aggregating) as numpyro deterministic sites. + + Parameters + ---------- + infections + Infections from the latent process. Shape ``(n_total,)`` + for aggregate, ``(n_total, n_subpops)`` for subpopulation-level. + first_day_dow + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``day_of_week_rv`` was set at construction or when + ``aggregation == "weekly"``. + right_truncation_offset + If set (together with ``right_truncation_rv``), the + number of additional reporting days that have occurred + since the last observation. + + Returns + ------- + ArrayLike + Predicted counts on the reporting-period grid; same + trailing dimensions as ``infections``. Equal to + predicted-daily when ``aggregation == "daily"``. + + Raises + ------ + ValueError + If ``day_of_week_rv`` was set but ``first_day_dow`` is + ``None``. + """ + predicted_daily = self._predicted_obs(infections) + if self.day_of_week_rv is not None: + if first_day_dow is None: + raise ValueError( + "first_day_dow is required when day_of_week_rv is set." + ) + predicted_daily = self._apply_day_of_week(predicted_daily, first_day_dow) + if self.right_truncation_rv is not None and right_truncation_offset is not None: + predicted_daily = self._apply_right_truncation( + predicted_daily, right_truncation_offset + ) + predicted = self._aggregate(predicted_daily, first_day_dow) + if self.aggregation == "weekly": + self._deterministic("predicted_daily", predicted_daily) + self._deterministic("predicted", predicted) + return predicted + + def _score_masked( + self, + predicted: ArrayLike, + obs: ArrayLike | None, + ) -> ArrayLike: + """ + Evaluate the masked likelihood on a dense period grid. + + Builds a boolean mask from the non-NaN positions of + ``predicted`` and (if provided) ``obs``, replaces NaN entries + with safe placeholder values, and delegates to + ``noise.sample`` with the mask. Shape-agnostic: works on 1D + period grids and 2D ``(n_periods, n_subpops)`` grids. + + Parameters + ---------- + predicted + Predicted counts on the reporting-period grid. NaN + entries mark the initialization period. + obs + Observed counts of the same shape as ``predicted``, or + ``None`` for prior predictive sampling. NaN entries mark + unobserved periods. + + Returns + ------- + ArrayLike + Sampled or conditioned counts from the noise model. + + Notes + ----- + JAX evaluates ``log_prob`` for every element regardless of + the mask; replacing NaN with finite placeholders prevents + NaN propagation in the trace while ``mask=False`` excludes + those entries from the likelihood sum. + """ + valid_pred = ~jnp.isnan(predicted) + if obs is not None: + valid_obs = ~jnp.isnan(obs) + mask = valid_pred & valid_obs + else: + mask = valid_pred + safe_predicted = jnp.where(jnp.isnan(predicted), 1.0, predicted) + safe_obs = None + if obs is not None: + safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) + return self.noise.sample( + name=self._sample_site_name("obs"), + predicted=safe_predicted, + obs=safe_obs, + mask=mask, + ) + + def _period_indices( + self, + period_end_times: ArrayLike, + first_day_dow: int | None, + ) -> jnp.ndarray: + """ + Convert daily-axis period-end indices to period-grid indices. + + For each daily-axis index ``t`` identifying the final day of + a reporting period, returns the position of that period in + the aggregated output. + + Parameters + ---------- + period_end_times + Daily-axis indices of each observed period's final day. + first_day_dow + Day-of-week index of element 0 of the daily axis. Required + when ``aggregation == "weekly"``. + + Returns + ------- + jnp.ndarray + Period-grid indices, one per entry in ``period_end_times``. + """ + P = self.aggregation_period + offset = self._compute_period_offset(first_day_dow, self.week) + return (jnp.asarray(period_end_times) - offset - (P - 1)) // P + + def _n_periods(self, n_total: int, first_day_dow: int | None) -> int: + """ + Return the number of complete reporting periods in ``n_total`` days. + + Parameters + ---------- + n_total + Total number of daily time steps (``n_init + n_days_post_init``). + first_day_dow + Day-of-week index of element 0 of the daily axis. Required + when ``aggregation == "weekly"``. + + Returns + ------- + int + ``n_total`` when ``aggregation == "daily"``; otherwise the + number of complete weekly periods after trimming the + leading partial week. + """ + if self.aggregation == "daily": + return n_total + P = self.aggregation_period + offset = self._compute_period_offset(first_day_dow, self.week) + return (n_total - offset) // P + class PopulationCounts(CountObservation): """ Aggregated count observation. Maps aggregate infections to counts through ascertainment x delay - convolution with composable noise model. + convolution with composable noise model. Predictions are constructed on + the daily model axis; ``aggregation_period`` controls whether those + predictions are scored as daily counts or summed to weekly counts before + the likelihood. Parameters ---------- @@ -293,6 +576,8 @@ def validate_data( n_total: int, n_subpops: int, obs: ArrayLike | None = None, + period_end_times: ArrayLike | None = None, + first_day_dow: int | None = None, **kwargs: object, ) -> None: """ @@ -301,21 +586,67 @@ def validate_data( Parameters ---------- n_total - Total number of time steps (n_init + n_days_post_init). + Total number of daily time steps (``n_init + n_days_post_init``). n_subpops Number of subpopulations (unused for aggregate observations). obs - Observed counts on shared time axis. Shape: (n_total,). + Observed counts. Shape depends on ``reporting_schedule``: + ``"regular"`` expects a dense array of length ``n_total // P`` + after front-trim, with NaN for unobserved periods; + ``"irregular"`` expects an array matching ``period_end_times``. + period_end_times + Daily-axis indices of each observed period's final day. Required + for ``reporting_schedule="irregular"``. + first_day_dow + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation == "weekly"`` so weekly observation periods can be + aligned to the shared daily model axis. **kwargs Additional keyword arguments (ignored). Raises ------ ValueError - If obs length doesn't match n_total. + If obs length or period_end_times fail their respective + checks, or if ``first_day_dow`` is missing when + ``aggregation == "weekly"``. """ + if self.reporting_schedule == "regular": + if obs is None: + return + if self.aggregation == "daily": + self._validate_obs_dense(obs, n_total) + return + n_periods = self._n_periods(n_total, first_day_dow) + obs = jnp.asarray(obs) + if obs.ndim != 1: + raise ValueError( + f"Observation '{self.name}': obs must be 1D, got shape {obs.shape}" + ) + if obs.shape[0] != n_periods: + raise ValueError( + f"Observation '{self.name}': obs length {obs.shape[0]} " + f"must equal n_periods ({n_periods}). " + f"Pad with NaN for unobserved periods." + ) + return + + if period_end_times is None: + if obs is None: + return + raise ValueError( + f"Observation '{self.name}': period_end_times is required " + f"when reporting_schedule='irregular' and obs is provided." + ) + offset = self._compute_period_offset(first_day_dow, self.week) + self._validate_period_end_times( + period_end_times, n_total, offset, self.aggregation_period + ) if obs is not None: - self._validate_obs_dense(obs, n_total) + self._validate_shapes_match( + obs, period_end_times, "obs", "period_end_times" + ) def sample( self, @@ -323,79 +654,79 @@ def sample( obs: ArrayLike | None = None, right_truncation_offset: int | None = None, first_day_dow: int | None = None, + period_end_times: ArrayLike | None = None, ) -> ObservationSample: """ Sample aggregated counts. - Both infections and obs use a shared time axis [0, n_total) where - n_total = n_init + n_days. NaN in obs marks unobserved timepoints - (initialization period or missing data). + Daily transforms (right-truncation, day-of-week) run on the + daily axis. When ``aggregation == "weekly"`` the daily + predictions are summed onto the reporting-period grid before + the noise model. Likelihood path depends on + ``reporting_schedule``: ``"regular"`` uses a dense-with-NaN + array plus a mask; ``"irregular"`` fancy-indexes the + aggregated array at period indices derived from + ``period_end_times``. + + ``aggregation_period`` describes the observation scale only. The + latent infection process may use daily or coarser Rt parameter + cadence, but by the time this method is called it supplies infections + on the daily model axis. Parameters ---------- infections Aggregate infections from the infection process. - Shape: (n_total,) where n_total = n_init + n_days. + Shape ``(n_total,)``. obs - Observed counts on shared time axis. Shape: (n_total,). - Use NaN for initialization period and any missing observations. - None for prior predictive sampling. + Observed counts. Shape depends on ``reporting_schedule``: + ``"regular"`` expects a dense array on the period grid + with NaN for unobserved periods; ``"irregular"`` expects + an array of the same length as ``period_end_times``. + ``None`` for prior predictive sampling. right_truncation_offset - If provided (and ``right_truncation_rv`` was set at construction), - apply right-truncation adjustment to predicted counts. - first_day_dow : int | None - Day of the week for the first timepoint on the shared time - axis (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Use - ``model.compute_first_day_dow(obs_start_dow)`` to convert - from the day of the week of the first observation. + If provided (and ``right_truncation_rv`` was set at + construction), apply right-truncation adjustment to the + daily predictions. + first_day_dow + Day-of-week index of the first timepoint on the shared + time axis (0=Monday, 6=Sunday, ISO convention). Required + when ``day_of_week_rv`` was set at construction or when + ``aggregation == "weekly"``. This aligns observation-level + day-of-week effects or weekly aggregation to the shared daily + model axis. + period_end_times + Daily-axis indices of each observed period's final day. + Required when ``reporting_schedule == "irregular"``. Returns ------- ObservationSample - Named tuple with `observed` (sampled/conditioned counts) and - `predicted` (predicted counts before noise, shape: n_total). + Named tuple with ``observed`` (sampled/conditioned counts) + and ``predicted`` (predictions on the reporting-period + grid; equal to daily predictions when + ``aggregation == "daily"``). """ - predicted_counts = self._predicted_obs(infections) - if self.day_of_week_rv is not None: - if first_day_dow is None: + predicted = self._compute_predicted( + infections, first_day_dow, right_truncation_offset + ) + + if self.reporting_schedule == "regular": + observed = self._score_masked(predicted, obs) + else: + if period_end_times is None: raise ValueError( - "first_day_dow is required when day_of_week_rv is set." + f"Observation '{self.name}': period_end_times is " + f"required when reporting_schedule == 'irregular'" ) - predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow) - if self.right_truncation_rv is not None and right_truncation_offset is not None: - predicted_counts = self._apply_right_truncation( - predicted_counts, right_truncation_offset + period_idx = self._period_indices(period_end_times, first_day_dow) + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=predicted[period_idx], + obs=obs, ) - self._deterministic("predicted", predicted_counts) - # Compute mask: True where observation contributes to likelihood. - # NaN in predictions (initialization period) or obs (missing data) - # are excluded via mask. - valid_pred = ~jnp.isnan(predicted_counts) - if obs is not None: - valid_obs = ~jnp.isnan(obs) - mask = valid_pred & valid_obs - else: - mask = valid_pred - - # JAX evaluates log_prob for all array elements even when mask - # excludes them from the likelihood sum. Replace NaN with safe values - # to avoid NaN propagation in JAX's computation graph. These values - # do not affect inference since mask=False excludes them. - safe_predicted = jnp.where(jnp.isnan(predicted_counts), 1.0, predicted_counts) - safe_obs = None - if obs is not None: - safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs) - - observed = self.noise.sample( - name=self._sample_site_name("obs"), - predicted=safe_predicted, - obs=safe_obs, - mask=mask, - ) - - return ObservationSample(observed=observed, predicted=predicted_counts) + return ObservationSample(observed=observed, predicted=predicted) class SubpopulationCounts(CountObservation): @@ -404,6 +735,10 @@ class SubpopulationCounts(CountObservation): Maps subpopulation-level infections to counts through ascertainment x delay convolution with composable noise model. + Predictions are constructed on the daily model axis for each + subpopulation; ``aggregation_period`` controls whether those predictions + are scored as daily counts or summed to weekly counts before the + likelihood. Parameters ---------- @@ -444,9 +779,10 @@ def validate_data( self, n_total: int, n_subpops: int, - times: ArrayLike | None = None, - subpop_indices: ArrayLike | None = None, obs: ArrayLike | None = None, + period_end_times: ArrayLike | None = None, + first_day_dow: int | None = None, + subpop_indices: ArrayLike | None = None, **kwargs: object, ) -> None: """ @@ -455,96 +791,180 @@ def validate_data( Parameters ---------- n_total - Total number of time steps (n_init + n_days_post_init). + Total number of daily time steps (``n_init + n_days_post_init``). n_subpops Number of subpopulations. - times - Day index for each observation on the shared time axis. - subpop_indices - Subpopulation index for each observation (0-indexed). obs - Observed counts (n_obs,). + Observed counts. For ``reporting_schedule="regular"`` + has shape ``(n_periods, n_observed_subpops)`` with NaN + for unobserved periods. For + ``reporting_schedule="irregular"`` has shape ``(n_obs,)`` + matching ``period_end_times`` and ``subpop_indices``. + period_end_times + Daily-axis indices of each observed period's final day. + Required for ``reporting_schedule="irregular"``. + first_day_dow + Day-of-week index of element 0 of the shared time axis + (0=Monday, 6=Sunday, ISO convention). Required when + ``aggregation == "weekly"`` so weekly observation periods can be + aligned to the shared daily model axis. + subpop_indices + Subpopulation indices (0-indexed). For + ``reporting_schedule="regular"``: shape + ``(n_observed_subpops,)`` selecting which subpopulation + columns appear in ``obs``. For + ``reporting_schedule="irregular"``: shape ``(n_obs,)`` + with one subpopulation per observation. **kwargs Additional keyword arguments (ignored). Raises ------ ValueError - If times or subpop_indices are out of bounds, or if - obs and times have mismatched lengths. + If any index array is out of bounds, any shape check + fails, or ``first_day_dow`` is missing when + ``aggregation == "weekly"``. """ - if times is not None: - self._validate_times(times, n_total) - if obs is not None: - self._validate_obs_times_shape(obs, times) if subpop_indices is not None: self._validate_subpop_indices(subpop_indices, n_subpops) + if self.reporting_schedule == "regular": + if obs is None: + return + n_periods = self._n_periods(n_total, first_day_dow) + obs = jnp.asarray(obs) + if obs.ndim != 2: + raise ValueError( + f"Observation '{self.name}': regular-schedule obs must " + f"be 2D (n_periods, n_observed_subpops); got shape {obs.shape}" + ) + if obs.shape[0] != n_periods: + raise ValueError( + f"Observation '{self.name}': obs dimension 0 length " + f"{obs.shape[0]} must equal n_periods ({n_periods}). " + f"Pad with NaN for unobserved periods." + ) + if subpop_indices is not None: + n_observed = jnp.asarray(subpop_indices).shape[0] + if obs.shape[1] != n_observed: + raise ValueError( + f"Observation '{self.name}': obs dimension 1 length " + f"{obs.shape[1]} must equal len(subpop_indices) " + f"({n_observed})" + ) + return + + if period_end_times is None: + if obs is None: + return + raise ValueError( + f"Observation '{self.name}': period_end_times is required " + f"when reporting_schedule='irregular' and obs is provided." + ) + offset = self._compute_period_offset(first_day_dow, self.week) + self._validate_period_end_times( + period_end_times, n_total, offset, self.aggregation_period + ) + if obs is not None: + self._validate_shapes_match( + obs, period_end_times, "obs", "period_end_times" + ) + if subpop_indices is not None: + self._validate_shapes_match( + subpop_indices, + period_end_times, + "subpop_indices", + "period_end_times", + ) + def sample( self, infections: ArrayLike, - times: ArrayLike, - subpop_indices: ArrayLike, obs: ArrayLike | None = None, right_truncation_offset: int | None = None, first_day_dow: int | None = None, + period_end_times: ArrayLike | None = None, + subpop_indices: ArrayLike | None = None, ) -> ObservationSample: """ Sample subpopulation-level counts. - Times are on the shared time axis [0, n_total) where - n_total = n_init + n_days. This method performs direct indexing - without any offset adjustment. + Daily transforms (right-truncation, day-of-week) run on the + daily axis. When ``aggregation == "weekly"`` the daily + predictions are summed onto the reporting-period grid before + the noise model. Likelihood path depends on + ``reporting_schedule``: ``"regular"`` selects the observed + subpopulation columns and uses a dense-with-NaN array plus + a mask; ``"irregular"`` fancy-indexes the aggregated array + at period indices derived from ``period_end_times``. + + ``aggregation_period`` describes the observation scale only. The + latent infection process may use daily or coarser Rt parameter + cadence, but by the time this method is called it supplies infections + on the daily model axis. Parameters ---------- infections Subpopulation-level infections from the infection process. - Shape: (n_total, n_subpops) - times - Day index for each observation on the shared time axis. - Must be in range [0, n_total). Shape: (n_obs,) - subpop_indices - Subpopulation index for each observation (0-indexed). - Shape: (n_obs,) + Shape ``(n_total, n_subpops)``. obs - Observed counts (n_obs,), or None for prior sampling. + Observed counts. For ``reporting_schedule="regular"``: + shape ``(n_periods, n_observed_subpops)`` with NaN for + unobserved periods. For + ``reporting_schedule="irregular"``: shape ``(n_obs,)`` + matching ``period_end_times`` and ``subpop_indices``. + ``None`` for prior predictive sampling. right_truncation_offset - If provided (and ``right_truncation_rv`` was set at construction), - apply right-truncation adjustment to predicted counts. - first_day_dow : int | None - Day of the week for the first timepoint on the shared time - axis (0=Monday, 6=Sunday, ISO convention). Required when - ``day_of_week_rv`` was set at construction. Use - ``model.compute_first_day_dow(obs_start_dow)`` to convert - from the day of the week of the first observation. + If provided (and ``right_truncation_rv`` was set at + construction), apply right-truncation adjustment to the + daily predictions. + first_day_dow + Day-of-week index of the first timepoint on the shared + time axis (0=Monday, 6=Sunday, ISO convention). Required + when ``day_of_week_rv`` was set at construction or when + ``aggregation == "weekly"``. This aligns observation-level + day-of-week effects or weekly aggregation to the shared daily + model axis. + period_end_times + Daily-axis indices of each observed period's final day. + Required when ``reporting_schedule == "irregular"``. + subpop_indices + Subpopulation indices (0-indexed). Required. For + ``reporting_schedule="regular"``: shape + ``(n_observed_subpops,)`` selecting which subpopulation + columns of the aggregated array enter the likelihood. + For ``reporting_schedule="irregular"``: shape ``(n_obs,)`` + with one subpopulation per observation. Returns ------- ObservationSample - Named tuple with `observed` (sampled/conditioned counts) and - `predicted` (predicted counts before noise, shape: n_total x n_subpops). + Named tuple with ``observed`` (sampled/conditioned counts) + and ``predicted`` (predictions on the reporting-period + grid, shape ``(n_periods, n_subpops)``; equal to daily + predictions when ``aggregation == "daily"``). """ - predicted_counts = self._predicted_obs(infections) - if self.day_of_week_rv is not None: - if first_day_dow is None: + if subpop_indices is None: + raise ValueError(f"Observation '{self.name}': subpop_indices is required.") + + predicted = self._compute_predicted( + infections, first_day_dow, right_truncation_offset + ) + + if self.reporting_schedule == "regular": + observed = self._score_masked(predicted[:, subpop_indices], obs) + else: + if period_end_times is None: raise ValueError( - "first_day_dow is required when day_of_week_rv is set." + f"Observation '{self.name}': period_end_times is " + f"required when reporting_schedule == 'irregular'" ) - predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow) - if self.right_truncation_rv is not None and right_truncation_offset is not None: - predicted_counts = self._apply_right_truncation( - predicted_counts, right_truncation_offset + period_idx = self._period_indices(period_end_times, first_day_dow) + observed = self.noise.sample( + name=self._sample_site_name("obs"), + predicted=predicted[period_idx, subpop_indices], + obs=obs, ) - self._deterministic("predicted", predicted_counts) - - # Direct indexing on shared time axis - no offset needed - predicted_obs = predicted_counts[times, subpop_indices] - - observed = self.noise.sample( - name=self._sample_site_name("obs"), - predicted=predicted_obs, - obs=obs, - ) - return ObservationSample(observed=observed, predicted=predicted_counts) + return ObservationSample(observed=observed, predicted=predicted) diff --git a/pyrenew/observation/measurement_observations.py b/pyrenew/observation/measurement_observations.py index e5c951a7..03fe13e2 100644 --- a/pyrenew/observation/measurement_observations.py +++ b/pyrenew/observation/measurement_observations.py @@ -146,7 +146,7 @@ def validate_data( if times is not None: self._validate_times(times, n_total) if obs is not None: - self._validate_obs_times_shape(obs, times) + self._validate_shapes_match(obs, times, "obs", "times") if subpop_indices is not None: self._validate_subpop_indices(subpop_indices, n_subpops) if sensor_indices is not None and n_sensors is not None: @@ -160,6 +160,7 @@ def sample( sensor_indices: ArrayLike, n_sensors: int, obs: ArrayLike | None = None, + **kwargs: object, ) -> ObservationSample: """ Sample measurements from observed sensors. @@ -189,6 +190,10 @@ def sample( Total number of measurement sensors. obs Observed measurements (n_obs,), or None for prior sampling. + **kwargs + Additional keyword arguments forwarded by the model + dispatch (e.g., ``first_day_dow``); ignored here because + measurement observations index the shared axis directly. Returns ------- diff --git a/pyrenew/time.py b/pyrenew/time.py index 8d9357c3..309a9805 100644 --- a/pyrenew/time.py +++ b/pyrenew/time.py @@ -1,11 +1,50 @@ """ -Helper functions for handling timeseries in Pyrenew - -Days of the week in pyrenew are 0-indexed and follow -ISO standards, so 0 is Monday at 6 is Sunday. +Helper functions for handling timeseries in Pyrenew. + +Days of the week in pyrenew are 0-indexed and follow ISO standards, +so 0 is Monday and 6 is Sunday. + +Calendar concepts used elsewhere in pyrenew +------------------------------------------- + +Two calendar concepts appear across the codebase. They name distinct +things; do not conflate them. + +- ``first_day_dow`` — day-of-week of element 0 of a model's shared + daily time axis. Sample-time data fact (depends on the dataset + being fit). Threaded through ``MultiSignalModel.sample`` and the + ``TemporalProcess.sample`` protocol. + +- :class:`WeekCycle` — a 7-day calendar cycle identified by its + ``start_dow``. Construction-time modeling choice, set independently on + each model component that does calendar-week work (weekly + :class:`CountObservation` and calendar-week-aligned + :class:`StepwiseTemporalProcess`). Different components can use + different cycles; each performs its own trimming and aggregation + relative to the shared daily axis. The module-level constants + :data:`MMWR_WEEK` (Sun-Sat) and :data:`ISO_WEEK` (Mon-Sun) cover + the common conventions; custom cycles use ``WeekCycle(start_dow=k)``. + +Worked example: ``first_day_dow=3`` (Thursday), ``week=MMWR_WEEK`` +(Sunday start), 17-day axis:: + + model index: 0 1 2 | 3 4 5 6 7 8 9 | 10 11 12 13 14 15 16 + weekday: Th Fr Sa | Su Mo Tu We Th Fr Sa | Su Mo Tu We Th Fr Sa + weekly Rt: [ c0 ] [ c1 ] [ c2 ] + +The first three indices fall in the leading partial week (``c0``); +each subsequent block of 7 days falls in one full Sunday-Saturday +week and shares one Rt sample. + +The low-level functions :func:`daily_to_weekly` and +:func:`weekly_to_daily` take integer day-of-week parameters +(``input_data_first_dow``, ``output_data_first_dow``, +``week_start_dow``) directly so that callers outside pyrenew can use +them without the :class:`WeekCycle` abstraction. """ import datetime as dt +from dataclasses import dataclass import jax.numpy as jnp import numpy as np @@ -53,6 +92,46 @@ def validate_dow(day_of_week: int, variable_name: str) -> None: return None +@dataclass(frozen=True) +class WeekCycle: + """ + A 7-day calendar cycle identified by its start day. + + Value object that names a single weekly convention (MMWR, + ISO, or custom) and is shared by any model component that + must agree on it. Equality is structural, so a builder-level + coherence check reduces to ``a == b``. + + Attributes + ---------- + start_dow + Day-of-week on which the cycle begins + (0=Monday, 6=Sunday, ISO convention). + """ + + start_dow: int + + def __post_init__(self) -> None: + """Validate ``start_dow`` via :func:`validate_dow`.""" + validate_dow(self.start_dow, "start_dow") + + @property + def end_dow(self) -> int: + """ + Day-of-week on which the cycle ends. + + Returns + ------- + int + ``(start_dow + 6) % 7``. + """ + return (self.start_dow + 6) % 7 + + +MMWR_WEEK: WeekCycle = WeekCycle(start_dow=6) +ISO_WEEK: WeekCycle = WeekCycle(start_dow=0) + + def get_sequential_day_of_week_indices( first_day_dow: int, n_timepoints: int ) -> jnp.ndarray: diff --git a/test/conftest.py b/test/conftest.py index c6dfb4c3..e9100532 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,8 +3,22 @@ This module provides reusable fixtures for creating observation processes, test data, and common configurations used across multiple test files. + +The module also sets two JAX environment variables before any jax import so +that all tests (unit and integration) run with 64-bit precision and with +four logical host devices available for parallel MCMC chains. JAX reads +these variables at import time, so they must be set before the first +``import jax`` anywhere in the test process; placing them at the top of +this file (loaded by pytest before any test module) satisfies that +requirement. The ``setdefault`` form respects any value the caller already +set at the shell level (e.g., a CI with different configuration). """ +import os + +os.environ.setdefault("JAX_ENABLE_X64", "true") +os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=4") + import jax.numpy as jnp import numpyro.distributions as dist import pytest @@ -19,9 +33,12 @@ from pyrenew.observation import ( HierarchicalNormalNoise, NegativeBinomialNoise, + PoissonNoise, PopulationCounts, + SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable +from pyrenew.time import MMWR_WEEK # ============================================================================= # PMF Fixtures @@ -260,6 +277,266 @@ def counts_factory(): return CountsProcessFactory() +@pytest.fixture +def weekly_regular_counts(simple_delay_pmf): + """ + PopulationCounts with weekly aggregation and regular (dense) reporting. + + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). + + Returns + ------- + PopulationCounts + A weekly-regular PopulationCounts process. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation="weekly", + reporting_schedule="regular", + week=MMWR_WEEK, + ) + + +@pytest.fixture +def weekly_irregular_counts(simple_delay_pmf): + """ + PopulationCounts with weekly aggregation and irregular (sparse) reporting. + + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). + + Returns + ------- + PopulationCounts + A weekly-irregular PopulationCounts process. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation="weekly", + reporting_schedule="irregular", + week=MMWR_WEEK, + ) + + +@pytest.fixture +def daily_irregular_counts(simple_delay_pmf): + """ + PopulationCounts with daily scale and irregular (sparse) reporting. + + Returns + ------- + PopulationCounts + A daily-irregular PopulationCounts process. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + reporting_schedule="irregular", + ) + + +@pytest.fixture +def weekly_regular_subpop_counts(simple_delay_pmf): + """ + SubpopulationCounts with weekly aggregation and regular (dense) reporting. + + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). + + Returns + ------- + SubpopulationCounts + A weekly-regular SubpopulationCounts process. + """ + return SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation="weekly", + reporting_schedule="regular", + week=MMWR_WEEK, + ) + + +@pytest.fixture +def weekly_irregular_subpop_counts(simple_delay_pmf): + """ + SubpopulationCounts with weekly aggregation and irregular (sparse) reporting. + + Uses ``MMWR_WEEK`` (Sunday-Saturday epiweeks). + + Returns + ------- + SubpopulationCounts + A weekly-irregular SubpopulationCounts process. + """ + return SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + aggregation="weekly", + reporting_schedule="irregular", + week=MMWR_WEEK, + ) + + +@pytest.fixture +def daily_irregular_subpop_counts(simple_delay_pmf): + """ + SubpopulationCounts with daily scale and irregular (sparse) reporting. + + Returns + ------- + SubpopulationCounts + A daily-irregular SubpopulationCounts process. + """ + return SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + reporting_schedule="irregular", + ) + + # ============================================================================= # Infection Fixtures # ============================================================================= + + +@pytest.fixture +def subpop_infections_28d(): + """ + Four-week ``(28, 3)`` infections array at 100/day/subpop. + + Used by weekly-aggregation tests on ``SubpopulationCounts``. + + Returns + ------- + jnp.ndarray + Shape ``(28, 3)`` filled with 100.0. + """ + return jnp.ones((28, 3)) * 100.0 + + +@pytest.fixture +def subpop_infections_30d(): + """ + Thirty-day ``(30, 3)`` infections array at 100/day/subpop. + + Used by daily-scale tests on ``SubpopulationCounts``. + + Returns + ------- + jnp.ndarray + Shape ``(30, 3)`` filled with 100.0. + """ + return jnp.ones((30, 3)) * 100.0 + + +@pytest.fixture +def mmwr_saturday_indices_first_three(): + """ + Daily-axis indices of the first three MMWR Saturdays for a Sunday-start axis. + + When element 0 of the daily axis is a Sunday (``first_day_dow=6``), + Saturdays fall on indices 6, 13, 20, ... + + Returns + ------- + jnp.ndarray + Shape ``(3,)`` containing ``[6, 13, 20]``. + """ + return jnp.array([6, 13, 20]) + + +# ============================================================================= +# Temporal Process Stubs +# ============================================================================= + + +class WrongShapeTemporalProcess: + """Temporal process stub that returns a fixed wrong-shaped array.""" + + step_size = 1 + + def __init__(self, value): + """ + Store the wrong-shaped value to return from ``sample``. + + Parameters + ---------- + value + Array returned by ``sample`` regardless of requested shape. + """ + self.value = value + + def sample(self, **kwargs): + """ + Return the configured wrong-shaped value. + + Returns + ------- + ArrayLike + The array passed to ``__init__``. + """ + return self.value + + +class ConstantTemporalProcess: + """Temporal process stub that returns zeros with the requested shape.""" + + step_size = 1 + + def sample(self, n_timepoints, n_processes=1, **kwargs): + """ + Return a correctly shaped zero trajectory. + + Parameters + ---------- + n_timepoints + Number of time points. + n_processes + Number of parallel processes. + + Returns + ------- + jnp.ndarray + Zeros of shape ``(n_timepoints, n_processes)``. + """ + return jnp.zeros((n_timepoints, n_processes)) + + +@pytest.fixture +def wrong_shape_temporal_process_cls(): + """ + Class producing a temporal-process stub that returns a fixed wrong shape. + + Returns + ------- + type + The ``WrongShapeTemporalProcess`` class. Instantiate with the array + the stub should return from ``sample``. + """ + return WrongShapeTemporalProcess + + +@pytest.fixture +def constant_temporal_process(): + """ + Temporal-process stub that returns a correctly shaped zero trajectory. + + Returns + ------- + ConstantTemporalProcess + Instance whose ``sample`` returns zeros of the requested shape. + """ + return ConstantTemporalProcess() diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 08b04b81..21049688 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -18,13 +18,15 @@ load_synthetic_daily_hospital_admissions, load_synthetic_daily_infections, load_synthetic_true_parameters, + load_synthetic_weekly_hospital_admissions, ) from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import AR1 +from pyrenew.latent import AR1, StepwiseTemporalProcess from pyrenew.latent.population_infections import PopulationInfections -from pyrenew.model import PyrenewBuilder +from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import NegativeBinomialNoise, PopulationCounts from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import MMWR_WEEK @pytest.fixture(scope="module") @@ -80,6 +82,19 @@ def daily_ed() -> pl.DataFrame: return load_synthetic_daily_ed_visits() +@pytest.fixture(scope="module") +def weekly_hosp() -> pl.DataFrame: + """ + Load synthetic weekly (MMWR epiweek) hospital admissions. + + Returns + ------- + pl.DataFrame + Columns: week_end, weekly_hosp_admits, location, pop. + """ + return load_synthetic_weekly_hospital_admissions() + + @pytest.fixture(scope="module") def hosp_delay_pmf() -> jnp.ndarray: """ @@ -135,7 +150,7 @@ def he_model( hosp_delay_pmf: jnp.ndarray, ed_delay_pmf: jnp.ndarray, ed_day_of_week_effects: jnp.ndarray, -) -> PyrenewBuilder: +) -> MultiSignalModel: """ Build a PopulationInfections model with hospital + ED observation processes. @@ -188,3 +203,146 @@ def he_model( builder.add_observation(ed_obs) return builder.build() + + +@pytest.fixture(scope="module") +def he_weekly_rt_model( + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """ + Build a PopulationInfections model with weekly-parameterized R(t). + + Same observation configuration as ``he_weekly_model`` (weekly hospital + admissions on the MMWR epiweek grid + daily ED visits with a day-of-week + effect), but R(t) is sampled weekly and broadcast to daily via + ``StepwiseTemporalProcess`` with calendar-week alignment. This mirrors the + production pyrenew-hew configuration. + + Parameters + ---------- + hosp_delay_pmf : jnp.ndarray + Infection-to-hospitalization delay PMF. + ed_delay_pmf : jnp.ndarray + Infection-to-ED-visit delay PMF. + ed_day_of_week_effects : jnp.ndarray + Day-of-week multipliers used in synthetic ED generation. + + Returns + ------- + MultiSignalModel + Built model ready for fitting. + """ + gen_int_pmf = jnp.array( + [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] + ) + + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), + I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + ) + + hospital_obs = PopulationCounts( + name="hospital", + ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) + ), + aggregation="weekly", + reporting_schedule="regular", + week=MMWR_WEEK, + ) + builder.add_observation(hospital_obs) + + ed_obs = PopulationCounts( + name="ed", + ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) + ), + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), + ) + builder.add_observation(ed_obs) + + return builder.build() + + +@pytest.fixture(scope="module") +def he_weekly_model( + hosp_delay_pmf: jnp.ndarray, + ed_delay_pmf: jnp.ndarray, + ed_day_of_week_effects: jnp.ndarray, +) -> MultiSignalModel: + """ + Build a PopulationInfections model with WEEKLY hospital + DAILY ED observations. + + The hospital observation is aggregated to MMWR epiweeks + (Sunday-Saturday, via ``MMWR_WEEK``); the ED observation stays + daily with a day-of-week effect. R(t) is parametrized at the + finest observation cadence (daily) per the coherence rules for + mixed-cadence models. + + Parameters + ---------- + hosp_delay_pmf : jnp.ndarray + Infection-to-hospitalization delay PMF. + ed_delay_pmf : jnp.ndarray + Infection-to-ED-visit delay PMF. + ed_day_of_week_effects : jnp.ndarray + Day-of-week multipliers used in synthetic ED generation. + + Returns + ------- + MultiSignalModel + Built model ready for fitting. + """ + gen_int_pmf = jnp.array( + [0.6326975, 0.2327564, 0.0856263, 0.03150015, 0.01158826, 0.00426308, 0.0015683] + ) + + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", gen_int_pmf), + I0_rv=DistributionalVariable("I0", dist.Beta(1, 10)), + log_rt_time_0_rv=DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)), + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + ) + + hospital_obs = PopulationCounts( + name="hospital", + ascertainment_rate_rv=DistributionalVariable("ihr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("hosp_delay", hosp_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("hosp_conc", dist.LogNormal(5.0, 1.0)) + ), + aggregation="weekly", + reporting_schedule="regular", + week=MMWR_WEEK, + ) + builder.add_observation(hospital_obs) + + ed_obs = PopulationCounts( + name="ed", + ascertainment_rate_rv=DistributionalVariable("iedr", dist.Beta(1, 100)), + delay_distribution_rv=DeterministicPMF("ed_delay", ed_delay_pmf), + noise=NegativeBinomialNoise( + DistributionalVariable("ed_conc", dist.LogNormal(4.0, 1.0)) + ), + day_of_week_rv=DeterministicVariable("ed_dow", ed_day_of_week_effects), + ) + builder.add_observation(ed_obs) + + return builder.build() diff --git a/test/integration/test_population_infections_he.py b/test/integration/test_population_infections_he.py index 12ac6553..16976028 100644 --- a/test/integration/test_population_infections_he.py +++ b/test/integration/test_population_infections_he.py @@ -2,12 +2,14 @@ Integration test: PopulationInfections H+E model with posterior recovery. Fits a PopulationInfections model with hospital admissions and ED visit -observation processes to synthetic 120-day CA data, then checks that +observation processes to synthetic 126-day CA data, then checks that posterior estimates recover known true parameters. """ from __future__ import annotations +from datetime import date + import arviz as az import jax import jax.numpy as jnp @@ -21,7 +23,7 @@ pytestmark = pytest.mark.integration -N_DAYS_FIT = 120 +N_DAYS_FIT = 126 NUM_WARMUP = 500 NUM_SAMPLES = 500 NUM_CHAINS = 4 @@ -48,9 +50,9 @@ def test_data_shapes( daily_infections : pl.DataFrame True infections and R(t). """ - assert len(daily_hosp) == 120 - assert len(daily_ed) == 120 - assert len(daily_infections) == 120 + assert len(daily_hosp) == 126 + assert len(daily_ed) == 126 + assert len(daily_infections) == 126 assert "daily_hosp_admits" in daily_hosp.columns assert "ed_visits" in daily_ed.columns assert "true_rt" in daily_infections.columns @@ -130,7 +132,6 @@ def fitted_model( ) population_size = float(daily_hosp["pop"][0]) - first_dow = 0 # 2023-11-06 is a Monday he_model.run( num_warmup=NUM_WARMUP, @@ -139,11 +140,9 @@ def fitted_model( mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, n_days_post_init=N_DAYS_FIT, population_size=population_size, + obs_start_date=date(2023, 11, 6), hospital={"obs": hosp_obs}, - ed={ - "obs": ed_obs, - "first_day_dow": he_model.compute_first_day_dow(first_dow), - }, + ed={"obs": ed_obs}, ) samples = he_model.mcmc.get_samples() diff --git a/test/integration/test_population_infections_he_weekly.py b/test/integration/test_population_infections_he_weekly.py new file mode 100644 index 00000000..048b0e6d --- /dev/null +++ b/test/integration/test_population_infections_he_weekly.py @@ -0,0 +1,458 @@ +""" +Integration test: PopulationInfections H+E model with WEEKLY hospital admissions. + +Same structure as ``test_population_infections_he`` but the hospital +signal is aggregated to MMWR epiweeks. Fits the mixed-cadence model +(weekly hospital + daily ED) to synthetic 126-day CA data and checks +posterior recovery. +""" + +from __future__ import annotations + +from datetime import date + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpy as np +import polars as pl +import pytest + +from pyrenew.model import MultiSignalModel +from pyrenew.time import MMWR_WEEK + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 +# First observation day of the synthetic data. 2023-11-05 is a Sunday (ISO dow = 6). +OBS_START_DATE = date(2023, 11, 5) + + +def _build_hospital_obs_on_period_grid( + model: MultiSignalModel, + weekly_values: jnp.ndarray, + first_day_dow: int, +) -> jnp.ndarray: + """ + Build a dense weekly-observation array on the model's period grid. + + Pads ``n_pre`` NaN values at the front for periods that precede the + first observed week (periods that overlap the initialization window + and any pre-observation gap). + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + weekly_values : jnp.ndarray + Observed weekly hospital admissions, one per MMWR epiweek. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + jnp.ndarray + Dense array of shape ``(n_periods,)`` with NaN for unobserved + periods and observed counts for periods covered by + ``weekly_values``. + """ + hosp = model.observations["hospital"] + n_init = model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + offset = hosp._compute_period_offset(first_day_dow, hosp.week) + n_periods = (n_total - offset) // hosp.aggregation_period + n_pre = n_periods - len(weekly_values) + return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) + + +class TestDataAssembly: + """Verify synthetic data can be loaded and aligned to the weekly model.""" + + def test_data_shapes( + self, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + daily_infections: pl.DataFrame, + ) -> None: + """ + Verify synthetic data files have expected row counts and columns. + + Parameters + ---------- + weekly_hosp : pl.DataFrame + Weekly MMWR-epiweek hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + daily_infections : pl.DataFrame + True infections and R(t). + """ + assert len(weekly_hosp) == 18 + assert len(daily_ed) == 126 + assert len(daily_infections) == 126 + assert "weekly_hosp_admits" in weekly_hosp.columns + assert "ed_visits" in daily_ed.columns + assert "true_rt" in daily_infections.columns + + def test_hospital_is_weekly_regular( + self, + he_weekly_model: MultiSignalModel, + ) -> None: + """ + Verify the hospital observation is weekly-aggregated and MMWR-anchored. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + """ + h = he_weekly_model.observations["hospital"] + assert h.aggregation == "weekly" + assert h.reporting_schedule == "regular" + assert h.week == MMWR_WEEK + + def test_ed_stays_daily( + self, + he_weekly_model: MultiSignalModel, + ) -> None: + """ + Verify the ED observation remains daily with a day-of-week effect. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + """ + ed = he_weekly_model.observations["ed"] + assert ed.aggregation == "daily" + assert ed.day_of_week_rv is not None + + def test_weekly_obs_alignment( + self, + he_weekly_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + ) -> None: + """ + Verify the weekly obs array has the correct dense-on-period-grid shape. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + """ + first_day_dow = he_weekly_model._resolve_first_day_dow(OBS_START_DATE) + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_model, weekly_values, first_day_dow + ) + + n_init = he_weekly_model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + hosp = he_weekly_model.observations["hospital"] + offset = hosp._compute_period_offset(first_day_dow, hosp.week) + n_periods = (n_total - offset) // hosp.aggregation_period + + assert hosp_obs.shape == (n_periods,) + assert jnp.isnan(hosp_obs[0]) + assert not jnp.isnan(hosp_obs[-1]) + assert int((~jnp.isnan(hosp_obs)).sum()) == len(weekly_hosp) + + +class TestModelFit: + """Fit the weekly H+E model and check posterior recovery.""" + + @pytest.fixture(scope="class") + def fitted_model( + self, + he_weekly_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """ + Fit the mixed-cadence H+E model to synthetic data via MCMC. + + Parameters + ---------- + he_weekly_model : MultiSignalModel + Built model. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + + Returns + ------- + MultiSignalModel + Model with MCMC results attached. + """ + first_day_dow = he_weekly_model._resolve_first_day_dow(OBS_START_DATE) + + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_model, weekly_values, first_day_dow + ) + + ed_obs = he_weekly_model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + population_size = float(weekly_hosp["pop"][0]) + + he_weekly_model.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, + ) + + samples = he_weekly_model.mcmc.get_samples() + jax.block_until_ready(samples) + return he_weekly_model + + @pytest.fixture(scope="class") + def posterior_dt( + self, + fitted_model: MultiSignalModel, + ): + """ + Convert MCMC samples to an ArviZ DataTree, trimming the init period. + + The hospital signal lives on the weekly period grid (dim ``week``); + the daily-scale sites (``time``) are trimmed by ``n_init``. + + Parameters + ---------- + fitted_model : MultiSignalModel + Model with MCMC results. + + Returns + ------- + xarray.DataTree + ArviZ DataTree with posterior group. + """ + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "hospital_predicted_daily": ["time"], + "hospital_predicted": ["week"], + "ed_predicted": ["time"], + }, + ) + + def trim_init(ds): + """ + Trim the initialization period from the ``time`` dimension only. + + Parameters + ---------- + ds + Dataset to trim. + + Returns + ------- + xarray.Dataset + Dataset with ``time`` sliced to ``[n_init:]``; ``week`` + and other dims pass through unchanged. + """ + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """ + Check that core parameters have acceptable Rhat and ESS. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_rt_posterior_covers_truth( + self, + posterior_dt, + daily_infections: pl.DataFrame, + ) -> None: + """ + Check that the 90% credible interval for R(t) covers the true value + for at least 80% of time points. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + daily_infections : pl.DataFrame + True R(t) trajectory. + """ + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] + rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values + rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values + + true_rt = daily_infections["true_rt"].to_numpy() + + if rt_q05.ndim > 1: + rt_q05 = rt_q05.squeeze() + rt_q95 = rt_q95.squeeze() + + n_compare = min(len(true_rt), len(rt_q05)) + covered = (true_rt[:n_compare] >= rt_q05[:n_compare]) & ( + true_rt[:n_compare] <= rt_q95[:n_compare] + ) + coverage = float(np.mean(covered)) + assert coverage >= 0.80, ( + f"R(t) 90% CI coverage was {coverage:.1%}, expected >= 80%" + ) + + def test_infection_trajectory_shape( + self, + posterior_dt, + ) -> None: + """ + Check the posterior infection trajectory has correct shape and is positive. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + infections = posterior_dt.posterior["latent_infections"] + assert infections.sizes["time"] == N_DAYS_FIT + assert (infections.values > 0).all() + + def test_ascertainment_rates_recover_order_of_magnitude( + self, + posterior_dt, + true_params: dict, + ) -> None: + """ + Check that posterior median IHR and IEDR are within a factor + of 5 of the true values. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + true_params : dict + Ground-truth parameter dictionary. + """ + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + + ihr_median = float( + posterior_dt.posterior["ihr"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior_dt.posterior["iedr"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5, ( + f"IHR median {ihr_median:.4f} not within 5x of true {true_ihr}" + ) + assert true_iedr / 5 <= iedr_median <= true_iedr * 5, ( + f"IEDR median {iedr_median:.4f} not within 5x of true {true_iedr}" + ) + + def test_hospital_predicted_weekly_grid( + self, + posterior_dt, + weekly_hosp: pl.DataFrame, + ) -> None: + """ + Check that hospital posterior predictions live on the weekly grid + and have plausible magnitude relative to observed weekly counts. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + """ + hosp_pred = posterior_dt.posterior["hospital_predicted"] + # n_periods from the weekly grid must be >= the number of observed weeks. + assert hosp_pred.sizes["week"] >= len(weekly_hosp) + + hosp_pred_median = float(hosp_pred.median(dim=["chain", "draw", "week"]).values) + hosp_obs_mean = float(weekly_hosp["weekly_hosp_admits"].mean()) + + assert hosp_pred_median > 0, "Hospital predictions should be positive" + assert hosp_obs_mean / 10 <= hosp_pred_median <= hosp_obs_mean * 10 + + def test_hospital_predicted_daily_has_daily_grid( + self, + posterior_dt, + ) -> None: + """ + Check that the ``hospital_predicted_daily`` deterministic site covers + the daily shared axis. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + hosp_pred_daily = posterior_dt.posterior["hospital_predicted_daily"] + assert hosp_pred_daily.sizes["time"] == N_DAYS_FIT + + def test_ed_predicted_reasonable( + self, + posterior_dt, + daily_ed: pl.DataFrame, + ) -> None: + """ + Check that daily ED posterior predictions have the right shape and + plausible magnitude relative to observed daily ED visits. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + daily_ed : pl.DataFrame + Daily ED visits. + """ + ed_pred = posterior_dt.posterior["ed_predicted"] + assert ed_pred.sizes["time"] == N_DAYS_FIT + + ed_pred_median = float(ed_pred.median(dim=["chain", "draw", "time"]).values) + ed_obs_mean = float(daily_ed["ed_visits"].mean()) + + assert ed_pred_median > 0, "ED predictions should be positive" + assert ed_obs_mean / 10 <= ed_pred_median <= ed_obs_mean * 10 + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/test/integration/test_population_infections_he_weekly_rt.py b/test/integration/test_population_infections_he_weekly_rt.py new file mode 100644 index 00000000..f1e438ac --- /dev/null +++ b/test/integration/test_population_infections_he_weekly_rt.py @@ -0,0 +1,383 @@ +""" +Integration test: PopulationInfections H+E model with WEEKLY R(t). + +Mirrors ``test_population_infections_he_weekly`` but parameterizes R(t) +weekly via ``StepwiseTemporalProcess(step_size=7, alignment="calendar_week")`` +and broadcasts to the daily renewal axis. This is the production +pyrenew-hew configuration: weekly hospital admissions + daily ED visits + +weekly calendar-aligned R(t). +""" + +from __future__ import annotations + +from datetime import date + +import arviz as az +import jax +import jax.numpy as jnp +import jax.random as random +import numpy as np +import numpyro +import polars as pl +import pytest + +from pyrenew.model import MultiSignalModel + +pytestmark = pytest.mark.integration + + +N_DAYS_FIT = 126 +NUM_WARMUP = 500 +NUM_SAMPLES = 500 +NUM_CHAINS = 4 +# First observation day of the synthetic data. 2023-11-05 is a Sunday (ISO dow = 6). +OBS_START_DATE = date(2023, 11, 5) +WEEK_START_DOW = 6 + + +def _build_hospital_obs_on_period_grid( + model: MultiSignalModel, + weekly_values: jnp.ndarray, + first_day_dow: int, +) -> jnp.ndarray: + """ + Build a dense weekly-observation array on the model's period grid. + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + weekly_values : jnp.ndarray + Observed weekly hospital admissions, one per MMWR epiweek. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + jnp.ndarray + Dense array of shape ``(n_periods,)`` with NaN for unobserved + periods and observed counts for periods covered by + ``weekly_values``. + """ + hosp = model.observations["hospital"] + n_init = model.latent.n_initialization_points + n_total = n_init + N_DAYS_FIT + offset = hosp._compute_period_offset(first_day_dow, hosp.week) + n_periods = (n_total - offset) // hosp.aggregation_period + n_pre = n_periods - len(weekly_values) + return jnp.concatenate([jnp.full(n_pre, jnp.nan, dtype=jnp.float32), weekly_values]) + + +def _expected_n_coarse(model: MultiSignalModel, first_day_dow: int) -> int: + """ + Expected number of coarse R(t) samples for calendar-week alignment. + + Mirrors ``StepwiseTemporalProcess._resolve_n_coarse`` for ``step_size=7`` + and ``alignment="calendar_week"``. + + Parameters + ---------- + model : MultiSignalModel + Built model exposing ``latent.n_initialization_points``. + first_day_dow : int + Day-of-week index of element 0 of the shared daily axis. + + Returns + ------- + int + Number of weekly Rt samples covering the daily model axis. + """ + n_total = model.latent.n_initialization_points + N_DAYS_FIT + trim = (first_day_dow - WEEK_START_DOW) % 7 + return (n_total + trim + 6) // 7 + + +class TestPriorPredictiveStructure: + """Verify the weekly-Rt graph records a coarse trajectory at the right shape.""" + + def test_coarse_rt_recorded( + self, + he_weekly_rt_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> None: + """ + Single prior-predictive sample exposes ``log_rt_single_coarse`` at the + expected coarse length and a daily-length broadcast Rt. + + Parameters + ---------- + he_weekly_rt_model : MultiSignalModel + Built model with calendar-aligned weekly Rt. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + """ + first_day_dow = he_weekly_rt_model._resolve_first_day_dow(OBS_START_DATE) + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_rt_model, weekly_values, first_day_dow + ) + ed_obs = he_weekly_rt_model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + population_size = float(weekly_hosp["pop"][0]) + + with numpyro.handlers.seed(rng_seed=0): + with numpyro.handlers.trace() as trace: + he_weekly_rt_model.sample( + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, + ) + + n_total = he_weekly_rt_model.latent.n_initialization_points + N_DAYS_FIT + n_coarse = _expected_n_coarse(he_weekly_rt_model, first_day_dow) + + coarse = trace["log_rt_single_coarse"]["value"] + daily = trace["PopulationInfections::log_rt_single"]["value"] + + assert coarse.shape == (n_coarse, 1) + assert daily.shape == (n_total, 1) + assert n_coarse < n_total + + # Each block of 7 daily values past the leading partial week should be + # constant (the calendar-week broadcast invariant). + partial_len = (WEEK_START_DOW - first_day_dow) % 7 + if partial_len > 0: + assert jnp.allclose(daily[:partial_len], daily[0]) + first_full = partial_len + assert jnp.allclose(daily[first_full : first_full + 7], daily[first_full]) + + +class TestModelFit: + """Fit the weekly-Rt H+E model and check posterior recovery.""" + + @pytest.fixture(scope="class") + def fitted_model( + self, + he_weekly_rt_model: MultiSignalModel, + weekly_hosp: pl.DataFrame, + daily_ed: pl.DataFrame, + ) -> MultiSignalModel: + """ + Fit the weekly-Rt H+E model to synthetic data via MCMC. + + Parameters + ---------- + he_weekly_rt_model : MultiSignalModel + Built model with calendar-aligned weekly Rt. + weekly_hosp : pl.DataFrame + Weekly hospital admissions. + daily_ed : pl.DataFrame + Daily ED visits. + + Returns + ------- + MultiSignalModel + Model with MCMC results attached. + """ + first_day_dow = he_weekly_rt_model._resolve_first_day_dow(OBS_START_DATE) + + weekly_values = jnp.array( + weekly_hosp["weekly_hosp_admits"].to_numpy(), dtype=jnp.float32 + ) + hosp_obs = _build_hospital_obs_on_period_grid( + he_weekly_rt_model, weekly_values, first_day_dow + ) + + ed_obs = he_weekly_rt_model.pad_observations( + jnp.array(daily_ed["ed_visits"].to_numpy(), dtype=jnp.float32) + ) + + population_size = float(weekly_hosp["pop"][0]) + + he_weekly_rt_model.run( + num_warmup=NUM_WARMUP, + num_samples=NUM_SAMPLES, + rng_key=random.PRNGKey(42), + mcmc_args={"num_chains": NUM_CHAINS, "progress_bar": False}, + n_days_post_init=N_DAYS_FIT, + population_size=population_size, + obs_start_date=OBS_START_DATE, + hospital={"obs": hosp_obs}, + ed={"obs": ed_obs}, + ) + + samples = he_weekly_rt_model.mcmc.get_samples() + jax.block_until_ready(samples) + return he_weekly_rt_model + + @pytest.fixture(scope="class") + def posterior_dt( + self, + fitted_model: MultiSignalModel, + ): + """ + Convert MCMC samples to an ArviZ DataTree, trimming the init period. + + Parameters + ---------- + fitted_model : MultiSignalModel + Model with MCMC results. + + Returns + ------- + xarray.DataTree + ArviZ DataTree with posterior group. + """ + n_init = fitted_model.latent.n_initialization_points + dt = az.from_numpyro( + fitted_model.mcmc, + dims={ + "latent_infections": ["time"], + "PopulationInfections::infections_aggregate": ["time"], + "PopulationInfections::log_rt_single": ["time", "dummy"], + "PopulationInfections::rt_single": ["time", "dummy"], + "log_rt_single_coarse": ["coarse_week", "dummy"], + "hospital_predicted_daily": ["time"], + "hospital_predicted": ["week"], + "ed_predicted": ["time"], + }, + ) + + def trim_init(ds): + """ + Trim the initialization period from the ``time`` dimension only. + + Parameters + ---------- + ds + Dataset to trim. + + Returns + ------- + xarray.Dataset + Dataset with ``time`` sliced to ``[n_init:]``; other dims + pass through unchanged. + """ + if "time" in ds.dims: + ds = ds.isel(time=slice(n_init, None)) + ds = ds.assign_coords(time=range(ds.sizes["time"])) + return ds + + return dt.map_over_datasets(trim_init) + + def test_mcmc_convergence( + self, + posterior_dt, + ) -> None: + """ + Check that core parameters have acceptable Rhat and ESS. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + summary = az.summary( + posterior_dt, + var_names=["I0", "log_rt_time_0", "ihr", "iedr"], + ) + rhat = summary["r_hat"].astype(float) + ess = summary["ess_bulk"].astype(float) + assert (rhat < 1.05).all(), f"Rhat exceeded 1.05:\n{summary[rhat >= 1.05]}" + assert (ess > 100).all(), f"ESS_bulk below 100:\n{summary[ess <= 100]}" + + def test_coarse_rt_posterior_shape( + self, + fitted_model: MultiSignalModel, + posterior_dt, + ) -> None: + """ + Check the coarse Rt site lives on the weekly cadence in the posterior. + + Parameters + ---------- + fitted_model : MultiSignalModel + Fitted model. + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + """ + first_day_dow = fitted_model._resolve_first_day_dow(OBS_START_DATE) + n_coarse = _expected_n_coarse(fitted_model, first_day_dow) + + coarse = posterior_dt.posterior["log_rt_single_coarse"] + assert coarse.sizes["coarse_week"] == n_coarse + + def test_rt_posterior_covers_truth( + self, + posterior_dt, + daily_infections: pl.DataFrame, + ) -> None: + """ + Check that the 90% credible interval for R(t) covers the true value + for at least 80% of time points. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + daily_infections : pl.DataFrame + True R(t) trajectory. + """ + rt_posterior = posterior_dt.posterior["PopulationInfections::rt_single"] + rt_q05 = rt_posterior.quantile(0.05, dim=["chain", "draw"]).values + rt_q95 = rt_posterior.quantile(0.95, dim=["chain", "draw"]).values + + true_rt = daily_infections["true_rt"].to_numpy() + + if rt_q05.ndim > 1: + rt_q05 = rt_q05.squeeze() + rt_q95 = rt_q95.squeeze() + + n_compare = min(len(true_rt), len(rt_q05)) + covered = (true_rt[:n_compare] >= rt_q05[:n_compare]) & ( + true_rt[:n_compare] <= rt_q95[:n_compare] + ) + coverage = float(np.mean(covered)) + assert coverage >= 0.80, ( + f"R(t) 90% CI coverage was {coverage:.1%}, expected >= 80%" + ) + + def test_ascertainment_rates_recover_order_of_magnitude( + self, + posterior_dt, + true_params: dict, + ) -> None: + """ + Check that posterior median IHR and IEDR are within a factor + of 5 of the true values. + + Parameters + ---------- + posterior_dt : xarray.DataTree + ArviZ DataTree with posterior group. + true_params : dict + Ground-truth parameter dictionary. + """ + true_ihr = true_params["hospitalizations"]["ihr"] + true_iedr = true_params["ed_visits"]["iedr"] + + ihr_median = float( + posterior_dt.posterior["ihr"].median(dim=["chain", "draw"]).values + ) + iedr_median = float( + posterior_dt.posterior["iedr"].median(dim=["chain", "draw"]).values + ) + + assert true_ihr / 5 <= ihr_median <= true_ihr * 5, ( + f"IHR median {ihr_median:.4f} not within 5x of true {true_ihr}" + ) + assert true_iedr / 5 <= iedr_median <= true_iedr * 5, ( + f"IEDR median {iedr_median:.4f} not within 5x of true {true_iedr}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integration"]) diff --git a/test/test_ar_process.py b/test/test_ar_process.py index 2af800dd..5bef827d 100755 --- a/test/test_ar_process.py +++ b/test/test_ar_process.py @@ -243,6 +243,23 @@ def test_ar_process_asymptotics(ar_inits, autoreg, noise_sd, n): :n ] + # Closed-form stationary standard deviation for the AR(p) process. + # The previous assertion used `3 * noise_sd`, which is only correct when + # the stationary variance equals the innovation variance (p=0); for AR(1) + # and AR(2) with non-zero autoreg, the stationary SD is strictly larger. + sigma2 = noise_sd**2 + if order == 1: + stationary_sd = jnp.sqrt(sigma2 / (1 - autoreg[0] ** 2)) + elif order == 2: + phi1, phi2 = autoreg[0], autoreg[1] + stationary_sd = jnp.sqrt( + sigma2 * (1 - phi2) / ((1 + phi2) * ((1 - phi2) ** 2 - phi1**2)) + ) + else: + raise NotImplementedError( + f"Stationary SD for AR order {order} not implemented in this test." + ) + with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it @@ -255,4 +272,4 @@ def test_ar_process_asymptotics(ar_inits, autoreg, noise_sd, n): ) assert_array_almost_equal(long_ts[:order], expected_first_entries) - assert jnp.abs(long_ts[-1]) < 3 * noise_sd + assert jnp.abs(long_ts[-1]) < 3 * stationary_sd diff --git a/test/test_datagen_he_CA_120.py b/test/test_datagen_he_CA_126.py similarity index 70% rename from test/test_datagen_he_CA_120.py rename to test/test_datagen_he_CA_126.py index 87aef8e9..9ab12e3d 100644 --- a/test/test_datagen_he_CA_120.py +++ b/test/test_datagen_he_CA_126.py @@ -1,5 +1,5 @@ """ -Unit tests for the datagen_he_CA_120 helper functions. +Unit tests for the datagen_he_CA_126 helper functions. """ import json @@ -8,11 +8,11 @@ import numpy as np import polars as pl -import pyrenew.datasets.datagen_he_CA_120 as datagen_mod -from pyrenew.datasets.datagen_he_CA_120 import ( - aggregate_to_epiweeks, +import pyrenew.datasets.datagen_he_CA_126 as datagen_mod +from pyrenew.datasets.datagen_he_CA_126 import ( apply_day_of_week_effects, build_true_rt, + build_weekly_hosp_frame, generate, run_renewal, sample_negbinom, @@ -23,9 +23,9 @@ class TestBuildTrueRt: """Tests for build_true_rt.""" def test_length(self): - """Test that the output has 120 days.""" + """Test that the output has 126 days.""" rt = build_true_rt() - assert len(rt) == 120 + assert len(rt) == 126 def test_starting_value(self): """Test that Rt starts near 1.2.""" @@ -35,8 +35,9 @@ def test_starting_value(self): def test_phase_endpoints(self): """Test that phase transitions occur at expected values.""" rt = build_true_rt() - assert rt[59] < 0.85 - assert rt[60] > 0.79 + # Phase 1 ends at index 62 (last day of 63-day decline), phase 2 starts at 63. + assert rt[62] < 0.85 + assert rt[63] > 0.79 class TestRunRenewal: @@ -71,22 +72,22 @@ def test_uniform_effects(self): """Test that uniform effects leave values unchanged.""" values = np.array([10.0, 20.0, 30.0]) dow = np.ones(7) - result = apply_day_of_week_effects(values, dow, first_dow=0) + result = apply_day_of_week_effects(values, dow, first_day_dow=0) np.testing.assert_allclose(result, values) def test_known_pattern(self): """Test that known day-of-week effects are applied correctly.""" values = np.ones(7) * 100.0 dow = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5]) - result = apply_day_of_week_effects(values, dow, first_dow=0) + result = apply_day_of_week_effects(values, dow, first_day_dow=0) np.testing.assert_allclose(result[5], 50.0) np.testing.assert_allclose(result[6], 50.0) def test_offset_start(self): - """Test that first_dow correctly offsets the pattern.""" + """Test that first_day_dow correctly offsets the pattern.""" values = np.ones(3) * 10.0 dow = np.array([2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) - result = apply_day_of_week_effects(values, dow, first_dow=6) + result = apply_day_of_week_effects(values, dow, first_day_dow=6) np.testing.assert_allclose(result[0], 10.0) np.testing.assert_allclose(result[1], 20.0) @@ -116,25 +117,39 @@ def test_handles_near_zero_mu(self): assert result.shape == mu.shape -class TestAggregateToEpiweeks: - """Tests for aggregate_to_epiweeks.""" +class TestBuildWeeklyHospFrame: + """Tests for build_weekly_hosp_frame.""" - def test_complete_weeks_only(self): - """Test that only complete 7-day weeks are returned.""" - daily = np.ones(20) + def test_length_matches_input(self): + """Frame length equals the number of weekly values supplied.""" + weekly = np.array([100.0, 200.0, 300.0]) start = date(2023, 11, 6) - result = aggregate_to_epiweeks(daily, start) + result = build_weekly_hosp_frame(weekly, start) assert isinstance(result, pl.DataFrame) - assert "weekly_hosp_admits" in result.columns - for val in result["weekly_hosp_admits"].to_list(): - assert val == 7.0 + assert len(result) == 3 - def test_weekly_sum(self): - """Test that weekly sums are correct.""" - daily = np.arange(1, 15, dtype=float) + def test_weekly_values_preserved(self): + """Input weekly values appear in the weekly_hosp_admits column.""" + weekly = np.array([100.0, 200.0, 300.0]) + start = date(2023, 11, 6) + result = build_weekly_hosp_frame(weekly, start) + assert result["weekly_hosp_admits"].to_list() == [100.0, 200.0, 300.0] + + def test_first_week_end_is_saturday(self): + """First week_end is the Saturday of the first complete MMWR epiweek.""" + # 2023-11-06 is a Monday; first MMWR Saturday after (Sun-start week) is 2023-11-18. + weekly = np.array([1.0, 2.0]) + start = date(2023, 11, 6) + result = build_weekly_hosp_frame(weekly, start) + assert result["week_end"].to_list() == [date(2023, 11, 18), date(2023, 11, 25)] + + def test_first_week_end_sunday_start(self): + """A Sunday start yields the immediately-following Saturday as first week_end.""" + # 2023-11-05 is a Sunday; first full MMWR week ends Saturday 2023-11-11. + weekly = np.array([5.0]) start = date(2023, 11, 5) - result = aggregate_to_epiweeks(daily, start) - assert len(result) >= 1 + result = build_weekly_hosp_frame(weekly, start) + assert result["week_end"].to_list() == [date(2023, 11, 11)] class TestGenerate: @@ -164,7 +179,7 @@ def test_generate_true_parameters_content(self, tmp_path, monkeypatch): params = json.load(f) assert params["population"] == datagen_mod.POPULATION - assert params["n_days"] == 120 + assert params["n_days"] == 126 assert "hospitalizations" in params assert "ed_visits" in params @@ -174,7 +189,7 @@ def test_generate_daily_infections_shape(self, tmp_path, monkeypatch): generate() df = pl.read_csv(tmp_path / "daily_infections.csv") - assert len(df) == 120 + assert len(df) == 126 assert "true_infections" in df.columns assert "true_rt" in df.columns @@ -184,7 +199,7 @@ def test_generate_hospital_admissions_shape(self, tmp_path, monkeypatch): generate() df = pl.read_csv(tmp_path / "daily_hospital_admissions.csv") - assert len(df) == 120 + assert len(df) == 126 assert "daily_hosp_admits" in df.columns def test_generate_ed_visits_shape(self, tmp_path, monkeypatch): @@ -193,5 +208,5 @@ def test_generate_ed_visits_shape(self, tmp_path, monkeypatch): generate() df = pl.read_csv(tmp_path / "daily_ed_visits.csv") - assert len(df) == 120 + assert len(df) == 126 assert "ed_visits" in df.columns diff --git a/test/test_datasets_synthetic.py b/test/test_datasets_synthetic.py index 52872d78..ca41e5c8 100644 --- a/test/test_datasets_synthetic.py +++ b/test/test_datasets_synthetic.py @@ -1,5 +1,5 @@ """ -Unit tests for synthetic CA 120-day dataset loaders. +Unit tests for synthetic CA 126-day dataset loaders. """ import polars as pl @@ -46,9 +46,9 @@ def test_population_is_positive(self): assert params["population"] > 0 def test_n_days_matches_data(self): - """Test that n_days is 120.""" + """Test that n_days is 126.""" params = load_synthetic_true_parameters() - assert params["n_days"] == 120 + assert params["n_days"] == 126 def test_hospitalization_params(self): """Test that hospitalization sub-dict has expected keys.""" @@ -83,10 +83,10 @@ def test_has_expected_columns(self): assert "true_infections" in df.columns assert "true_rt" in df.columns - def test_has_120_rows(self): - """Test that the dataset has 120 days.""" + def test_has_126_rows(self): + """Test that the dataset has 126 days.""" df = load_synthetic_daily_infections() - assert len(df) == 120 + assert len(df) == 126 def test_infections_are_positive(self): """Test that true infections are positive.""" @@ -114,10 +114,10 @@ def test_has_expected_columns(self): assert "daily_hosp_admits" in df.columns assert "pop" in df.columns - def test_has_120_rows(self): - """Test that the dataset has 120 days.""" + def test_has_126_rows(self): + """Test that the dataset has 126 days.""" df = load_synthetic_daily_hospital_admissions() - assert len(df) == 120 + assert len(df) == 126 def test_admits_are_non_negative(self): """Test that hospital admissions are non-negative.""" @@ -139,10 +139,10 @@ def test_has_expected_columns(self): assert "date" in df.columns assert "ed_visits" in df.columns - def test_has_120_rows(self): - """Test that the dataset has 120 days.""" + def test_has_126_rows(self): + """Test that the dataset has 126 days.""" df = load_synthetic_daily_ed_visits() - assert len(df) == 120 + assert len(df) == 126 def test_visits_are_non_negative(self): """Test that ED visits are non-negative.""" diff --git a/test/test_observation_counts.py b/test/test_observation_counts.py index 82bea1a3..b00f842e 100644 --- a/test/test_observation_counts.py +++ b/test/test_observation_counts.py @@ -15,6 +15,7 @@ SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable +from pyrenew.time import MMWR_WEEK from test.test_helpers import create_mock_infections @@ -283,6 +284,7 @@ def test_non_contiguous_subpop_indices(self): ascertainment_rate_rv=DeterministicVariable("ihr", 1.0), delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), + reporting_schedule="irregular", ) # 5 subpopulations with distinct infection levels @@ -291,13 +293,13 @@ def test_non_contiguous_subpop_indices(self): for k in range(5): infections = infections.at[:, k].set((k + 1) * 100.0) - times = jnp.array([10, 10, 10]) + period_end_times = jnp.array([10, 10, 10]) subpop_indices = jnp.array([0, 2, 4]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, ) @@ -550,18 +552,19 @@ def test_counts_by_subpop_2d_broadcasting(self): delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), right_truncation_rv=DeterministicPMF("rt_delay", rt_pmf), + reporting_schedule="irregular", ) n_days = 10 n_subpops = 3 infections = jnp.ones((n_days, n_subpops)) * 100 - times = jnp.array([0, 8, 9]) + period_end_times = jnp.array([0, 8, 9]) subpop_indices = jnp.array([0, 1, 2]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, right_truncation_offset=0, @@ -739,18 +742,19 @@ def test_counts_by_subpop_2d_broadcasting(self): delay_distribution_rv=DeterministicPMF("delay", delay_pmf), noise=PoissonNoise(), day_of_week_rv=DeterministicVariable("dow", dow_effect), + reporting_schedule="irregular", ) n_days = 14 n_subpops = 3 infections = jnp.ones((n_days, n_subpops)) * 100 - times = jnp.array([0, 1, 7]) + period_end_times = jnp.array([0, 1, 7]) subpop_indices = jnp.array([0, 1, 2]) with numpyro.handlers.seed(rng_seed=42): result = process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, first_day_dow=0, @@ -843,21 +847,641 @@ def test_counts_by_subpop_dow_without_offset_raises(self): delay_distribution_rv=DeterministicPMF("delay", jnp.array([1.0])), noise=PoissonNoise(), day_of_week_rv=DeterministicVariable("dow", dow_effect), + reporting_schedule="irregular", ) infections = jnp.ones((14, 2)) * 100 - times = jnp.array([0, 1]) + period_end_times = jnp.array([0, 1]) subpop_indices = jnp.array([0, 1]) with numpyro.handlers.seed(rng_seed=42): with pytest.raises(ValueError, match="first_day_dow is required"): process.sample( infections=infections, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=None, first_day_dow=None, ) +# =================================================================== +# PopulationCounts with aggregation: construction-time validation +# =================================================================== + + +class TestPopulationCountsAggregationConstruction: + """Construction-time validation for the aggregation parameters.""" + + def _make(self, simple_delay_pmf, **kwargs): + """ + Build a PopulationCounts with a stub noise model and optional overrides. + + Returns + ------- + PopulationCounts + A PopulationCounts instance using the supplied delay PMF and + any additional constructor overrides passed via ``**kwargs``. + """ + return PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + **kwargs, + ) + + def test_default_construction_is_daily_regular(self, simple_delay_pmf): + """Default constructor yields aggregation='daily', reporting_schedule='regular'.""" + process = self._make(simple_delay_pmf) + assert process.aggregation == "daily" + assert process.reporting_schedule == "regular" + assert process.week is None + + def test_weekly_requires_week(self, simple_delay_pmf): + """aggregation='weekly' without week must raise.""" + with pytest.raises(ValueError, match="week is required"): + self._make(simple_delay_pmf, aggregation="weekly") + + def test_weekly_with_mmwr_anchor_constructs(self, simple_delay_pmf): + """aggregation='weekly' with MMWR_WEEK is valid.""" + process = self._make(simple_delay_pmf, aggregation="weekly", week=MMWR_WEEK) + assert process.aggregation == "weekly" + assert process.week == MMWR_WEEK + + def test_dow_effect_with_weekly_aggregation_raises(self, simple_delay_pmf): + """day_of_week_rv cannot be combined with aggregation='weekly'.""" + with pytest.raises(ValueError, match="day_of_week_rv cannot be combined"): + self._make( + simple_delay_pmf, + aggregation="weekly", + week=MMWR_WEEK, + day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), + ) + + def test_dow_effect_with_daily_aggregation_allowed(self, simple_delay_pmf): + """day_of_week_rv remains valid for aggregation='daily'.""" + process = self._make( + simple_delay_pmf, + day_of_week_rv=DeterministicVariable("dow", jnp.ones(7)), + ) + assert process.day_of_week_rv is not None + + def test_unknown_aggregation_raises(self, simple_delay_pmf): + """aggregation must be 'daily' or 'weekly'.""" + with pytest.raises(ValueError, match="aggregation must be one of"): + self._make(simple_delay_pmf, aggregation="monthly") + + def test_unknown_reporting_schedule_raises(self, simple_delay_pmf): + """reporting_schedule must be 'regular' or 'irregular'.""" + with pytest.raises(ValueError, match="reporting_schedule must be one of"): + self._make(simple_delay_pmf, reporting_schedule="sporadic") + + +# =================================================================== +# PopulationCounts with aggregation: validate_data +# =================================================================== + + +class TestPopulationCountsAggregationValidateData: + """validate_data branches for each (schedule, aggregation_period) combination.""" + + def test_weekly_regular_correct_length_passes(self, weekly_regular_counts): + """Weekly-regular obs of length n_periods passes when first_day_dow aligns.""" + obs = jnp.ones(4) * 10.0 + weekly_regular_counts.validate_data( + n_total=28, n_subpops=1, obs=obs, first_day_dow=6 + ) + + def test_weekly_regular_wrong_length_raises(self, weekly_regular_counts): + """Weekly-regular obs with wrong length raises.""" + obs = jnp.ones(28) * 10.0 + with pytest.raises(ValueError, match="must equal n_periods"): + weekly_regular_counts.validate_data( + n_total=28, n_subpops=1, obs=obs, first_day_dow=6 + ) + + def test_weekly_regular_missing_first_day_dow_raises(self, weekly_regular_counts): + """Weekly-regular with first_day_dow=None raises.""" + obs = jnp.ones(4) * 10.0 + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_regular_counts.validate_data(n_total=28, n_subpops=1, obs=obs) + + def test_weekly_regular_obs_none_passes(self, weekly_regular_counts): + """Weekly-regular with obs=None skips checks.""" + weekly_regular_counts.validate_data(n_total=28, n_subpops=1) + + def test_weekly_regular_honors_offset(self, weekly_regular_counts): + """Weekly-regular n_periods reflects front-trim offset.""" + obs = jnp.ones(3) * 10.0 + weekly_regular_counts.validate_data( + n_total=28, n_subpops=1, obs=obs, first_day_dow=0 + ) + + def test_weekly_irregular_aligned_times_pass(self, weekly_irregular_counts): + """Weekly-irregular with Saturdays at offset 0 passes.""" + period_end_times = jnp.array([6, 13, 20]) + obs = jnp.ones(3) * 10.0 + weekly_irregular_counts.validate_data( + n_total=28, + n_subpops=1, + obs=obs, + period_end_times=period_end_times, + first_day_dow=6, + ) + + def test_weekly_irregular_misaligned_times_raise(self, weekly_irregular_counts): + """Weekly-irregular with non-Saturday period_end_times raises.""" + period_end_times = jnp.array([6, 12, 20]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + weekly_irregular_counts.validate_data( + n_total=28, + n_subpops=1, + period_end_times=period_end_times, + first_day_dow=6, + ) + + def test_weekly_irregular_missing_first_day_dow_raises( + self, weekly_irregular_counts + ): + """Weekly-irregular with first_day_dow=None raises.""" + period_end_times = jnp.array([6, 13, 20]) + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_irregular_counts.validate_data( + n_total=28, n_subpops=1, period_end_times=period_end_times + ) + + def test_weekly_irregular_obs_shape_mismatch_raises(self, weekly_irregular_counts): + """Weekly-irregular obs of wrong length raises.""" + period_end_times = jnp.array([6, 13, 20]) + obs = jnp.ones(2) * 10.0 + with pytest.raises(ValueError, match="must match"): + weekly_irregular_counts.validate_data( + n_total=28, + n_subpops=1, + obs=obs, + period_end_times=period_end_times, + first_day_dow=6, + ) + + def test_daily_irregular_passes(self, daily_irregular_counts): + """Daily-irregular validates via bounds only; alignment is trivial.""" + period_end_times = jnp.array([0, 5, 19]) + obs = jnp.ones(3) * 10.0 + daily_irregular_counts.validate_data( + n_total=20, + n_subpops=1, + obs=obs, + period_end_times=period_end_times, + ) + + def test_daily_irregular_out_of_bounds_raises(self, daily_irregular_counts): + """Daily-irregular out-of-bounds index raises.""" + period_end_times = jnp.array([0, 5, 25]) + with pytest.raises(ValueError, match="upper bound"): + daily_irregular_counts.validate_data( + n_total=20, n_subpops=1, period_end_times=period_end_times + ) + + def test_irregular_no_period_end_times_is_noop(self, weekly_irregular_counts): + """Irregular schedule with period_end_times=None returns without error.""" + weekly_irregular_counts.validate_data( + n_total=28, n_subpops=1, obs=None, period_end_times=None + ) + + +# =================================================================== +# PopulationCounts with aggregation: sample +# =================================================================== + + +class TestPopulationCountsAggregationSample: + """Sample-time behavior for the new aggregation paths.""" + + def test_weekly_regular_predicted_shape(self, weekly_regular_counts): + """Weekly-regular predicted has shape (n_periods,) equal to weekly sums.""" + infections = jnp.ones(28) * 100.0 + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_counts.sample( + infections=infections, first_day_dow=6 + ) + assert result.predicted.shape == (4,) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + + def test_weekly_regular_emits_predicted_daily_site(self, weekly_regular_counts): + """When aggregation_period > 1, the 'predicted_daily' deterministic site exists.""" + infections = jnp.ones(28) * 100.0 + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + weekly_regular_counts.sample(infections=infections, first_day_dow=6) + assert "hosp_predicted_daily" in trace + assert "hosp_predicted" in trace + assert trace["hosp_predicted_daily"]["value"].shape == (28,) + assert trace["hosp_predicted"]["value"].shape == (4,) + + def test_daily_backward_compat_no_predicted_daily_site(self, simple_delay_pmf): + """When aggregation_period == 1, 'predicted_daily' is not emitted.""" + process = PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + infections = jnp.ones(28) * 100.0 + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + process.sample(infections=infections) + assert "hosp_predicted_daily" not in trace + assert "hosp_predicted" in trace + assert trace["hosp_predicted"]["value"].shape == (28,) + + def test_weekly_regular_with_obs_runs(self, weekly_regular_counts): + """Weekly-regular sample accepts dense-with-NaN obs on the period grid.""" + infections = jnp.ones(28) * 100.0 + obs = jnp.array([7.0, jnp.nan, 7.0, 7.0]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_counts.sample( + infections=infections, obs=obs, first_day_dow=6 + ) + assert result.observed.shape == (4,) + assert result.predicted.shape == (4,) + + def test_weekly_irregular_period_indexing(self, weekly_irregular_counts): + """Weekly-irregular fancy-indexes the aggregated array at correct periods.""" + infections = jnp.ones(28) * 100.0 + period_end_times = jnp.array([6, 20]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_irregular_counts.sample( + infections=infections, + period_end_times=period_end_times, + first_day_dow=6, + ) + assert result.predicted.shape == (4,) + assert result.observed.shape == (2,) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + + def test_weekly_irregular_missing_period_end_times_raises( + self, weekly_irregular_counts + ): + """Irregular schedule requires period_end_times at sample time.""" + infections = jnp.ones(28) * 100.0 + with pytest.raises(ValueError, match="period_end_times is required"): + with numpyro.handlers.seed(rng_seed=42): + weekly_irregular_counts.sample(infections=infections, first_day_dow=6) + + def test_daily_irregular_period_indexing(self, daily_irregular_counts): + """Daily-irregular fancy-indexes at the supplied daily indices directly.""" + infections = jnp.ones(30) * 100.0 + period_end_times = jnp.array([5, 10, 20]) + with numpyro.handlers.seed(rng_seed=42): + result = daily_irregular_counts.sample( + infections=infections, period_end_times=period_end_times + ) + assert result.predicted.shape == (30,) + assert result.observed.shape == (3,) + + def test_aggregate_helper_missing_first_day_dow_raises(self, weekly_regular_counts): + """_aggregate raises when aggregation == 'weekly' and first_day_dow is None.""" + predicted_daily = jnp.ones(28) + with pytest.raises( + ValueError, match="first_day_dow is required when aggregation == 'weekly'" + ): + weekly_regular_counts._aggregate(predicted_daily, first_day_dow=None) + + +# =================================================================== +# SubpopulationCounts with aggregation: validate_data +# =================================================================== + + +class TestSubpopulationCountsAggregationValidateData: + """validate_data branches for each (schedule, aggregation_period) combination.""" + + def test_daily_regular_valid_passes(self, simple_delay_pmf): + """Daily-regular dense 2D obs (n_total, n_observed_subpops) passes.""" + process = SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + obs = jnp.ones((30, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + process.validate_data( + n_total=30, n_subpops=3, obs=obs, subpop_indices=subpop_indices + ) + + def test_weekly_regular_valid_passes(self, weekly_regular_subpop_counts): + """Weekly-regular dense 2D obs (n_periods, n_observed_subpops) passes.""" + obs = jnp.ones((4, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_regular_no_obs_is_noop(self, weekly_regular_subpop_counts): + """Regular schedule with obs=None returns without error.""" + weekly_regular_subpop_counts.validate_data(n_total=28, n_subpops=3, obs=None) + + def test_weekly_regular_wrong_n_periods_raises(self, weekly_regular_subpop_counts): + """Weekly-regular obs with wrong dim-0 length raises.""" + obs = jnp.ones((28, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + with pytest.raises(ValueError, match="must equal n_periods"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_regular_wrong_n_subpops_raises(self, weekly_regular_subpop_counts): + """Regular-schedule obs dim-1 must equal len(subpop_indices).""" + obs = jnp.ones((4, 3)) * 5.0 + subpop_indices = jnp.array([0, 2]) + with pytest.raises(ValueError, match=r"must equal len\(subpop_indices\)"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_regular_missing_first_day_dow_raises( + self, weekly_regular_subpop_counts + ): + """Weekly-regular without first_day_dow raises.""" + obs = jnp.ones((4, 2)) * 5.0 + subpop_indices = jnp.array([0, 2]) + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + subpop_indices=subpop_indices, + ) + + def test_regular_1d_obs_raises(self, weekly_regular_subpop_counts): + """Regular-schedule obs must be 2D.""" + obs = jnp.ones(4) * 5.0 + subpop_indices = jnp.array([0]) + with pytest.raises(ValueError, match="regular-schedule obs must be 2D"): + weekly_regular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_regular_bad_subpop_indices_raises(self, weekly_regular_subpop_counts): + """Regular-schedule out-of-bounds subpop_indices raises.""" + subpop_indices = jnp.array([0, 5]) + with pytest.raises(ValueError, match="upper bound"): + weekly_regular_subpop_counts.validate_data( + n_total=28, n_subpops=3, subpop_indices=subpop_indices + ) + + def test_daily_irregular_valid_passes(self, daily_irregular_subpop_counts): + """Daily-irregular with valid period_end_times and subpop_indices passes.""" + period_end_times = jnp.array([5, 10, 15, 20]) + subpop_indices = jnp.array([0, 1, 2, 0]) + obs = jnp.array([10.0, 20.0, 30.0, 15.0]) + daily_irregular_subpop_counts.validate_data( + n_total=30, + n_subpops=3, + obs=obs, + period_end_times=period_end_times, + subpop_indices=subpop_indices, + ) + + def test_weekly_irregular_valid_passes( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Weekly-irregular with Saturdays at offset 0 passes.""" + subpop_indices = jnp.array([0, 1, 2]) + obs = jnp.array([10.0, 20.0, 30.0]) + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + period_end_times=mmwr_saturday_indices_first_three, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_irregular_misaligned_raises(self, weekly_irregular_subpop_counts): + """Weekly-irregular with non-Saturday period_end_times raises.""" + period_end_times = jnp.array([6, 12, 20]) + subpop_indices = jnp.array([0, 1, 2]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + period_end_times=period_end_times, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_weekly_irregular_missing_first_day_dow_raises( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Weekly-irregular without first_day_dow raises.""" + with pytest.raises(ValueError, match="first_day_dow is required"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + period_end_times=mmwr_saturday_indices_first_three, + ) + + def test_irregular_obs_shape_mismatch_raises( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Irregular-schedule obs length must match period_end_times.""" + obs = jnp.array([10.0, 20.0]) + with pytest.raises(ValueError, match="must match"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + obs=obs, + period_end_times=mmwr_saturday_indices_first_three, + first_day_dow=6, + ) + + def test_irregular_subpop_indices_shape_mismatch_raises( + self, weekly_irregular_subpop_counts, mmwr_saturday_indices_first_three + ): + """Irregular-schedule subpop_indices length must match period_end_times.""" + subpop_indices = jnp.array([0, 1]) + with pytest.raises(ValueError, match="must match"): + weekly_irregular_subpop_counts.validate_data( + n_total=28, + n_subpops=3, + period_end_times=mmwr_saturday_indices_first_three, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + +# =================================================================== +# SubpopulationCounts with aggregation: sample +# =================================================================== + + +class TestSubpopulationCountsAggregationSample: + """Sample-time behavior for the new aggregation paths.""" + + def test_daily_regular_shape(self, simple_delay_pmf, subpop_infections_30d): + """Daily-regular sample returns predicted of shape (n_total, n_subpops).""" + process = SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = process.sample( + infections=subpop_infections_30d, subpop_indices=subpop_indices + ) + assert result.predicted.shape == (30, 3) + assert result.observed.shape == (30, 2) + + def test_weekly_regular_predicted_shape_and_values( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """Weekly-regular aggregates (n_total, n_subpops) to (n_periods, n_subpops).""" + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (4, 3) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + assert result.observed.shape == (4, 2) + + def test_weekly_regular_emits_predicted_daily_site( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """aggregation_period > 1 emits a 'predicted_daily' deterministic site.""" + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert "ed_predicted_daily" in trace + assert "ed_predicted" in trace + assert trace["ed_predicted_daily"]["value"].shape == (28, 3) + assert trace["ed_predicted"]["value"].shape == (4, 3) + + def test_daily_backward_compat_no_predicted_daily_site( + self, simple_delay_pmf, subpop_infections_28d + ): + """aggregation_period == 1 emits only 'predicted', not 'predicted_daily'.""" + process = SubpopulationCounts( + name="ed", + ascertainment_rate_rv=DeterministicVariable("iedr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", simple_delay_pmf), + noise=PoissonNoise(), + ) + subpop_indices = jnp.array([0, 1, 2]) + with numpyro.handlers.trace() as trace: + with numpyro.handlers.seed(rng_seed=42): + process.sample( + infections=subpop_infections_28d, subpop_indices=subpop_indices + ) + assert "ed_predicted_daily" not in trace + assert "ed_predicted" in trace + assert trace["ed_predicted"]["value"].shape == (28, 3) + + def test_missing_subpop_indices_raises( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """subpop_indices is required in every sample call.""" + with pytest.raises(ValueError, match="subpop_indices is required"): + with numpyro.handlers.seed(rng_seed=42): + weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, first_day_dow=6 + ) + + def test_weekly_irregular_period_indexing( + self, weekly_irregular_subpop_counts, subpop_infections_28d + ): + """Weekly-irregular fancy-indexes the aggregated array at (period, subpop).""" + period_end_times = jnp.array([6, 20]) + subpop_indices = jnp.array([0, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_irregular_subpop_counts.sample( + infections=subpop_infections_28d, + period_end_times=period_end_times, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (4, 3) + assert result.observed.shape == (2,) + assert jnp.allclose(result.predicted, 7.0, rtol=1e-5) + + def test_weekly_irregular_missing_period_end_times_raises( + self, weekly_irregular_subpop_counts, subpop_infections_28d + ): + """Irregular schedule requires period_end_times at sample time.""" + subpop_indices = jnp.array([0, 1, 2]) + with pytest.raises(ValueError, match="period_end_times is required"): + with numpyro.handlers.seed(rng_seed=42): + weekly_irregular_subpop_counts.sample( + infections=subpop_infections_28d, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + + def test_daily_irregular_fancy_indexing( + self, daily_irregular_subpop_counts, subpop_infections_30d + ): + """Daily-irregular indexes predicted at (period_end_times, subpop_indices).""" + period_end_times = jnp.array([5, 10, 20]) + subpop_indices = jnp.array([0, 1, 2]) + with numpyro.handlers.seed(rng_seed=42): + result = daily_irregular_subpop_counts.sample( + infections=subpop_infections_30d, + period_end_times=period_end_times, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (30, 3) + assert result.observed.shape == (3,) + + def test_weekly_regular_with_obs_conditions( + self, weekly_regular_subpop_counts, subpop_infections_28d + ): + """Weekly-regular sample conditions on 2D obs with NaN-padding for unobserved periods.""" + subpop_indices = jnp.array([0, 2]) + obs = jnp.array( + [ + [7.0, 7.0], + [jnp.nan, jnp.nan], + [7.0, 7.0], + [7.0, 7.0], + ] + ) + with numpyro.handlers.seed(rng_seed=42): + result = weekly_regular_subpop_counts.sample( + infections=subpop_infections_28d, + obs=obs, + first_day_dow=6, + subpop_indices=subpop_indices, + ) + assert result.predicted.shape == (4, 3) + assert result.observed.shape == (4, 2) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_observation_validation.py b/test/test_observation_validation.py index 3632f04a..a52b6724 100644 --- a/test/test_observation_validation.py +++ b/test/test_observation_validation.py @@ -19,6 +19,7 @@ SubpopulationCounts, ) from pyrenew.randomvariable import DistributionalVariable, VectorizedVariable +from pyrenew.time import MMWR_WEEK, WeekCycle, daily_to_weekly # --------------------------------------------------------------------------- # Helpers – minimal concrete subclass of MeasurementObservation for testing @@ -95,6 +96,7 @@ def subpop_proc(): ascertainment_rate_rv=DeterministicVariable("ihr", 0.02), delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.4, 0.4, 0.2])), noise=PoissonNoise(), + reporting_schedule="irregular", ) @@ -188,6 +190,32 @@ def test_non_contiguous_valid_indices(self, counts_proc): param_name="test_idx", ) + def test_empty_array_passes(self, counts_proc): + """Empty index array has no values to bounds-check; should not raise.""" + counts_proc._validate_index_array( + jnp.array([], dtype=jnp.int32), + upper_bound=10, + param_name="test_idx", + ) + + def test_scalar_raises(self, counts_proc): + """0-D scalar index array should raise ValueError.""" + with pytest.raises(ValueError, match="must be 1D"): + counts_proc._validate_index_array( + jnp.asarray(0), + upper_bound=5, + param_name="test_idx", + ) + + def test_2d_raises(self, counts_proc): + """2D index array should raise ValueError.""" + with pytest.raises(ValueError, match="must be 1D"): + counts_proc._validate_index_array( + jnp.array([[0, 1], [2, 3]]), + upper_bound=5, + param_name="test_idx", + ) + # =================================================================== # _validate_times @@ -250,44 +278,56 @@ def test_subpop_index_above_n_subpops_raises(self, counts_proc): # =================================================================== -# _validate_obs_times_shape +# _validate_shapes_match # =================================================================== -class TestValidateObsTimesShape: - """Tests for BaseObservationProcess._validate_obs_times_shape.""" +class TestValidateShapesMatch: + """Tests for BaseObservationProcess._validate_shapes_match.""" def test_matching_1d_shapes(self, counts_proc): - """Obs and times with matching 1D shapes should not raise.""" - obs = jnp.array([1.0, 2.0, 3.0]) - times = jnp.array([0, 5, 10]) - counts_proc._validate_obs_times_shape(obs, times) + """Two arrays with matching 1D shapes should not raise.""" + a = jnp.array([1.0, 2.0, 3.0]) + b = jnp.array([0, 5, 10]) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_mismatched_lengths_raises(self, counts_proc): - """Obs and times with different lengths should raise ValueError.""" - obs = jnp.array([1.0, 2.0, 3.0]) - times = jnp.array([0, 5]) + """Two arrays with different lengths should raise ValueError.""" + a = jnp.array([1.0, 2.0, 3.0]) + b = jnp.array([0, 5]) with pytest.raises(ValueError, match="must match times shape"): - counts_proc._validate_obs_times_shape(obs, times) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_empty_arrays_match(self, counts_proc): """Two empty arrays should not raise.""" - obs = jnp.array([]) - times = jnp.array([]) - counts_proc._validate_obs_times_shape(obs, times) + a = jnp.array([]) + b = jnp.array([]) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_scalar_arrays_match(self, counts_proc): """Single-element arrays should not raise.""" - obs = jnp.array([5.0]) - times = jnp.array([0]) - counts_proc._validate_obs_times_shape(obs, times) + a = jnp.array([5.0]) + b = jnp.array([0]) + counts_proc._validate_shapes_match(a, b, "obs", "times") def test_error_includes_both_shapes(self, counts_proc): """Error message should report both shapes.""" - obs = jnp.array([1.0, 2.0]) - times = jnp.array([0, 1, 2]) + a = jnp.array([1.0, 2.0]) + b = jnp.array([0, 1, 2]) with pytest.raises(ValueError, match=r"\(2,\).*\(3,\)"): - counts_proc._validate_obs_times_shape(obs, times) + counts_proc._validate_shapes_match(a, b, "obs", "times") + + def test_error_uses_provided_names(self, counts_proc): + """Error message should use the caller-supplied parameter names.""" + a = jnp.array([1.0, 2.0]) + b = jnp.array([0, 1, 2]) + with pytest.raises( + ValueError, + match=r"subpop_indices shape .* must match period_end_times shape", + ): + counts_proc._validate_shapes_match( + a, b, "subpop_indices", "period_end_times" + ) # =================================================================== @@ -326,6 +366,12 @@ def test_error_includes_lengths(self, counts_proc): with pytest.raises(ValueError, match="15.*30"): counts_proc._validate_obs_dense(obs, n_total=30) + def test_2d_obs_raises(self, counts_proc): + """2D obs (e.g., shape (n_total, 1)) should raise ValueError.""" + obs = jnp.ones((30, 1)) + with pytest.raises(ValueError, match="obs must be 1D"): + counts_proc._validate_obs_dense(obs, n_total=30) + # =================================================================== # PopulationCounts.validate_data() @@ -355,6 +401,36 @@ def test_wrong_length_obs_raises(self, counts_proc): with pytest.raises(ValueError, match="must equal n_total"): counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + def test_2d_obs_daily_raises(self, counts_proc): + """Daily regular-schedule obs must be 1D; 2D (n_total, 1) should raise.""" + obs = jnp.ones((30, 1)) + with pytest.raises(ValueError, match="obs must be 1D"): + counts_proc.validate_data(n_total=30, n_subpops=1, obs=obs) + + def test_2d_obs_weekly_raises(self): + """Weekly regular-schedule obs must be 1D; 2D (n_periods, 1) should raise.""" + proc = PopulationCounts( + name="hosp", + ascertainment_rate_rv=DeterministicVariable("ihr", 0.01), + delay_distribution_rv=DeterministicPMF("delay", jnp.array([0.3, 0.5, 0.2])), + noise=PoissonNoise(), + aggregation="weekly", + reporting_schedule="regular", + week=MMWR_WEEK, + ) + n_total = 35 + first_day_dow = 6 + offset = proc._compute_period_offset(first_day_dow, proc.week) + n_periods = (n_total - offset) // proc.aggregation_period + obs = jnp.ones((n_periods, 1)) + with pytest.raises(ValueError, match="obs must be 1D"): + proc.validate_data( + n_total=n_total, + n_subpops=1, + obs=obs, + first_day_dow=first_day_dow, + ) + def test_extra_kwargs_ignored(self, counts_proc): """validate_data should ignore extra keyword arguments.""" counts_proc.validate_data( @@ -375,29 +451,33 @@ def test_all_none_passes(self, subpop_proc): subpop_proc.validate_data(n_total=30, n_subpops=3) def test_valid_data_passes(self, subpop_proc): - """validate_data with valid times, subpop_indices, obs should not raise.""" - times = jnp.array([5, 10, 15, 20]) + """validate_data with valid period_end_times, subpop_indices, obs should not raise.""" + period_end_times = jnp.array([5, 10, 15, 20]) subpop_indices = jnp.array([0, 1, 2, 0]) obs = jnp.array([10.0, 20.0, 30.0, 15.0]) subpop_proc.validate_data( n_total=30, n_subpops=3, - times=times, + period_end_times=period_end_times, subpop_indices=subpop_indices, obs=obs, ) def test_invalid_times_raises(self, subpop_proc): - """validate_data with out-of-bounds times should raise.""" - times = jnp.array([5, 30]) # 30 == n_total, out of bounds + """validate_data with out-of-bounds period_end_times should raise.""" + period_end_times = jnp.array([5, 30]) # 30 == n_total, out of bounds with pytest.raises(ValueError, match="upper bound"): - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + subpop_proc.validate_data( + n_total=30, n_subpops=3, period_end_times=period_end_times + ) def test_negative_times_raises(self, subpop_proc): - """validate_data with negative times should raise.""" - times = jnp.array([-1, 5]) + """validate_data with negative period_end_times should raise.""" + period_end_times = jnp.array([-1, 5]) with pytest.raises(ValueError, match="cannot be negative"): - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + subpop_proc.validate_data( + n_total=30, n_subpops=3, period_end_times=period_end_times + ) def test_invalid_subpop_indices_raises(self, subpop_proc): """validate_data with out-of-bounds subpop_indices should raise.""" @@ -416,22 +496,31 @@ def test_negative_subpop_indices_raises(self, subpop_proc): ) def test_mismatched_obs_times_raises(self, subpop_proc): - """validate_data with obs/times shape mismatch should raise.""" - times = jnp.array([5, 10, 15]) + """validate_data with obs/period_end_times shape mismatch should raise.""" + period_end_times = jnp.array([5, 10, 15]) obs = jnp.array([1.0, 2.0]) # length 2 != length 3 - with pytest.raises(ValueError, match="must match times shape"): - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times, obs=obs) + with pytest.raises( + ValueError, match=r"obs shape .* must match period_end_times shape" + ): + subpop_proc.validate_data( + n_total=30, + n_subpops=3, + period_end_times=period_end_times, + obs=obs, + ) - def test_obs_without_times_skips_shape_check(self, subpop_proc): - """validate_data with obs but no times should not check shapes.""" + def test_obs_without_times_raises(self, subpop_proc): + """validate_data with obs but no period_end_times should raise ValueError.""" obs = jnp.array([1.0, 2.0]) - # times is None, so shape check is skipped - subpop_proc.validate_data(n_total=30, n_subpops=3, obs=obs) + with pytest.raises(ValueError, match="period_end_times is required"): + subpop_proc.validate_data(n_total=30, n_subpops=3, obs=obs) def test_times_without_obs_skips_shape_check(self, subpop_proc): - """validate_data with times but no obs should validate times only.""" - times = jnp.array([5, 10, 15]) - subpop_proc.validate_data(n_total=30, n_subpops=3, times=times) + """validate_data with period_end_times but no obs should validate indices only.""" + period_end_times = jnp.array([5, 10, 15]) + subpop_proc.validate_data( + n_total=30, n_subpops=3, period_end_times=period_end_times + ) def test_non_contiguous_subpop_indices_valid(self, subpop_proc): """validate_data with non-contiguous but valid subpop_indices passes.""" @@ -440,6 +529,13 @@ def test_non_contiguous_subpop_indices_valid(self, subpop_proc): n_total=30, n_subpops=3, subpop_indices=subpop_indices ) + def test_scalar_subpop_indices_raises(self, subpop_proc): + """Scalar (0-D) subpop_indices should raise ValueError, not IndexError.""" + with pytest.raises(ValueError, match="must be 1D"): + subpop_proc.validate_data( + n_total=30, n_subpops=3, subpop_indices=jnp.asarray(0) + ) + # =================================================================== # MeasurementObservation.validate_data() @@ -535,7 +631,7 @@ def test_mismatched_obs_times_raises(self, measurements_proc): """validate_data with obs/times shape mismatch should raise.""" times = jnp.array([5, 10, 15]) obs = jnp.array([1.0, 2.0]) - with pytest.raises(ValueError, match="must match times shape"): + with pytest.raises(ValueError, match=r"obs shape .* must match times shape"): measurements_proc.validate_data( n_total=30, n_subpops=3, times=times, obs=obs ) @@ -556,3 +652,194 @@ def test_non_contiguous_subpop_indices_valid(self, measurements_proc): def test_extra_kwargs_ignored(self, measurements_proc): """validate_data should ignore extra keyword arguments.""" measurements_proc.validate_data(n_total=30, n_subpops=3, foo="bar") + + +# =================================================================== +# _validate_week +# =================================================================== + + +class TestValidateWeek: + """Tests for BaseObservationProcess._validate_week.""" + + def test_daily_with_no_week_passes(self, counts_proc): + """aggregation='daily' with week=None should not raise.""" + counts_proc._validate_week("daily", None) + + def test_daily_ignores_week(self, counts_proc): + """aggregation='daily' ignores any supplied WeekCycle.""" + counts_proc._validate_week("daily", MMWR_WEEK) + + def test_weekly_with_week_passes(self, counts_proc): + """aggregation='weekly' with any WeekCycle should not raise.""" + counts_proc._validate_week("weekly", MMWR_WEEK) + counts_proc._validate_week("weekly", WeekCycle(start_dow=0)) + + def test_weekly_missing_week_raises(self, counts_proc): + """aggregation='weekly' with week=None should raise.""" + with pytest.raises(ValueError, match="week is required"): + counts_proc._validate_week("weekly", None) + + @pytest.mark.parametrize("bad", ["monthly", "biweekly", "Weekly", ""]) + def test_unknown_aggregation_raises(self, counts_proc, bad): + """Unrecognized aggregation strings should raise.""" + with pytest.raises(ValueError, match="aggregation must be one of"): + counts_proc._validate_week(bad, MMWR_WEEK) + + +# =================================================================== +# _compute_period_offset +# =================================================================== + + +class TestComputePeriodOffset: + """Tests for BaseObservationProcess._compute_period_offset.""" + + def test_daily_returns_zero(self, counts_proc): + """week=None always returns 0 regardless of first_day_dow.""" + assert counts_proc._compute_period_offset(None, None) == 0 + assert counts_proc._compute_period_offset(0, None) == 0 + assert counts_proc._compute_period_offset(6, None) == 0 + + def test_mmwr_aligned_start(self, counts_proc): + """Daily axis starting Sunday under MMWR_WEEK => offset 0.""" + assert counts_proc._compute_period_offset(6, MMWR_WEEK) == 0 + + def test_monday_start_mmwr(self, counts_proc): + """Daily axis starting Monday under MMWR_WEEK => offset 6.""" + assert counts_proc._compute_period_offset(0, MMWR_WEEK) == 6 + + def test_saturday_start_mmwr(self, counts_proc): + """Daily axis starting Saturday under MMWR_WEEK => offset 1.""" + assert counts_proc._compute_period_offset(5, MMWR_WEEK) == 1 + + def test_iso_week_alignment(self, counts_proc): + """Daily axis starting Thursday under ISO week (Mon start) => offset 4.""" + assert counts_proc._compute_period_offset(3, WeekCycle(start_dow=0)) == 4 + + def test_offset_always_in_range(self, counts_proc): + """Offset is always in [0, 7) for any valid (first_day_dow, week).""" + for first in range(7): + for start in range(7): + offset = counts_proc._compute_period_offset( + first, WeekCycle(start_dow=start) + ) + assert 0 <= offset < 7 + + def test_missing_first_day_dow_raises(self, counts_proc): + """week set with first_day_dow=None should raise.""" + with pytest.raises(ValueError, match="first_day_dow is required"): + counts_proc._compute_period_offset(None, MMWR_WEEK) + + def test_offset_agrees_with_daily_to_weekly(self, counts_proc): + """ + Offset from _compute_period_offset selects the same leading + days that daily_to_weekly trims internally for the first + complete period, for every (first_day_dow, week) pair. + """ + daily = jnp.arange(21.0) + for first in range(7): + for start in range(7): + week = WeekCycle(start_dow=start) + offset = counts_proc._compute_period_offset(first, week) + weekly = daily_to_weekly( + daily, + input_data_first_dow=first, + week_start_dow=start, + ) + expected_first_week = float(jnp.sum(daily[offset : offset + 7])) + assert float(weekly[0]) == expected_first_week + + +# =================================================================== +# _validate_period_end_times +# =================================================================== + + +class TestValidatePeriodEndTimes: + """Tests for BaseObservationProcess._validate_period_end_times.""" + + def test_p7_aligned_times_pass(self, counts_proc): + """P=7 with Saturdays at offset 0 should not raise.""" + times = jnp.array([6, 13, 20]) + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_p7_aligned_nonzero_offset_pass(self, counts_proc): + """P=7 with nonzero offset shifts the boundary days.""" + times = jnp.array([12, 19, 26]) + counts_proc._validate_period_end_times( + times, n_total=30, offset=6, aggregation_period=7 + ) + + def test_p1_any_in_bounds_passes(self, counts_proc): + """P=1: alignment is trivial; any in-bounds index passes.""" + times = jnp.array([0, 3, 7, 19]) + counts_proc._validate_period_end_times( + times, n_total=20, offset=0, aggregation_period=1 + ) + + def test_misaligned_time_raises(self, counts_proc): + """P=7 with a non-boundary time (past the first complete period) should raise.""" + times = jnp.array([7]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_partial_misalignment_raises(self, counts_proc): + """Any single misaligned entry should raise.""" + times = jnp.array([6, 12, 20]) + with pytest.raises(ValueError, match="period_end_times must lie on"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_negative_time_raises(self, counts_proc): + """Negative period_end_times should raise via bounds check.""" + times = jnp.array([-1, 6]) + with pytest.raises(ValueError, match="cannot be negative"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_time_at_n_total_raises(self, counts_proc): + """period_end_times == n_total should raise via bounds check.""" + times = jnp.array([21]) + with pytest.raises(ValueError, match="upper bound"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_error_reports_offset_and_period(self, counts_proc): + """Alignment error message should include offset and aggregation_period.""" + times = jnp.array([7]) + with pytest.raises(ValueError, match=r"offset=0.*aggregation_period=7"): + counts_proc._validate_period_end_times( + times, n_total=21, offset=0, aggregation_period=7 + ) + + def test_time_before_first_complete_period_raises(self, counts_proc): + """ + An entry before the first complete period's final day must raise. + + With ``offset=6, aggregation_period=7``, ``t=5`` satisfies the + modulo boundary check ``(5 - 6) % 7 == 6`` (Python negative + modulo), but maps to ``period_idx = -1`` under fancy indexing, + which JAX silently wraps. The lower-bound check must reject it. + """ + times = jnp.array([5]) + with pytest.raises( + ValueError, match="do not correspond to a complete aggregation period" + ): + counts_proc._validate_period_end_times( + times, n_total=20, offset=6, aggregation_period=7 + ) + + def test_time_at_first_complete_period_end_passes(self, counts_proc): + """The first complete period's final day is the lower-bound edge case.""" + times = jnp.array([12]) + counts_proc._validate_period_end_times( + times, n_total=20, offset=6, aggregation_period=7 + ) diff --git a/test/test_population_infections.py b/test/test_population_infections.py index 261ac8f7..cbe6febc 100644 --- a/test/test_population_infections.py +++ b/test/test_population_infections.py @@ -7,8 +7,9 @@ import pytest from pyrenew.deterministic import DeterministicVariable -from pyrenew.latent import RandomWalk +from pyrenew.latent import AR1, RandomWalk, StepwiseTemporalProcess from pyrenew.latent.population_infections import PopulationInfections +from pyrenew.time import MMWR_WEEK class TestPopulationInfectionsSample: @@ -125,6 +126,49 @@ def test_custom_name_prefix(self, gen_int_rv): assert "my_infections::rt_single" in trace + def test_first_day_dow_reaches_calendar_aligned_rt_process(self, gen_int_rv): + """Calendar-aligned stepwise Rt receives model-axis day of week.""" + process = PopulationInfections( + name="population", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + n_initialization_points=7, + ) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + process.sample(n_days_post_init=10, first_day_dow=3) + + log_rt = trace["population::log_rt_single"]["value"] + coarse = trace["log_rt_single_coarse"]["value"] + + assert log_rt.shape == (17, 1) + assert coarse.shape == (3, 1) + assert jnp.allclose(log_rt[:3], log_rt[0]) + assert jnp.allclose(log_rt[3:10], log_rt[3]) + assert jnp.allclose(log_rt[10:17], log_rt[10]) + + def test_wrong_rt_shape_raises(self, gen_int_rv, wrong_shape_temporal_process_cls): + """Temporal processes must return daily-length single-process Rt.""" + process = PopulationInfections( + name="population", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("log_rt_time_0", 0.0), + single_rt_process=wrong_shape_temporal_process_cls(jnp.zeros((16, 2))), + n_initialization_points=7, + ) + + with pytest.raises(ValueError, match="single_rt_process must return shape"): + process.sample(n_days_post_init=10) + class TestPopulationInfectionsValidation: """Test validation of inputs.""" diff --git a/test/test_pyrenew_builder.py b/test/test_pyrenew_builder.py index ff60a309..d2315530 100644 --- a/test/test_pyrenew_builder.py +++ b/test/test_pyrenew_builder.py @@ -2,22 +2,58 @@ Tests for PyrenewBuilder and MultiSignalModel. """ +from datetime import date, timedelta + import jax.numpy as jnp +import numpyro import pytest from pyrenew.deterministic import DeterministicPMF, DeterministicVariable -from pyrenew.latent import RandomWalk, SubpopulationInfections +from pyrenew.latent import ( + AR1, + PopulationInfections, + RandomWalk, + StepwiseTemporalProcess, + SubpopulationInfections, +) from pyrenew.model import MultiSignalModel, PyrenewBuilder from pyrenew.observation import ( NegativeBinomialNoise, + PoissonNoise, PopulationCounts, SubpopulationCounts, ) +from pyrenew.time import MMWR_WEEK, WeekCycle # Standard population structure for tests (3 subpopulations) SUBPOP_FRACTIONS = jnp.array([0.3, 0.25, 0.45]) +def _obs_date_for_dow(target_first_day_dow: int, n_init: int) -> date: + """ + Return an ``obs_start_date`` whose axis-origin day-of-week matches. + + Given a desired ``first_day_dow`` (day-of-week of element 0 of the + padded axis) and the model's ``n_init``, pick a concrete date for + the first observation day such that + ``(obs_start_date.weekday() - n_init) % 7 == target_first_day_dow``. + + Parameters + ---------- + target_first_day_dow + Desired axis-origin day-of-week in ``{0, ..., 6}``. + n_init + Initialization-period length in days. + + Returns + ------- + datetime.date + A date with the required day-of-week, drawn from January 2024. + """ + obs_dow = (target_first_day_dow + n_init) % 7 + return date(2024, 1, 1) + timedelta(days=obs_dow) + + @pytest.fixture def simple_builder(): """ @@ -93,6 +129,7 @@ def validation_builder(): ascertainment_rate_rv=DeterministicVariable("ihr_subpop", 0.01), delay_distribution_rv=delay, noise=NegativeBinomialNoise(DeterministicVariable("conc_subpop", 10.0)), + reporting_schedule="irregular", ) ) @@ -267,6 +304,120 @@ def test_prior_predictive_multi_signal(self, simple_builder): # All prior predictive infections should be positive assert jnp.all(prior_samples["latent_infections"] > 0) + def test_first_day_dow_reaches_calendar_aligned_latent_process(self): + """MultiSignalModel forwards model-axis day of week to the latent process.""" + latent = PopulationInfections( + name="PopulationInfections", + gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + n_initialization_points=3, + ) + model = MultiSignalModel(latent, {"ed": _daily_ed_counts()}) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + model.sample( + n_days_post_init=10, + population_size=1_000_000, + obs_start_date=_obs_date_for_dow(target_first_day_dow=3, n_init=3), + ed={"obs": None}, + ) + + log_rt = trace["PopulationInfections::log_rt_single"]["value"] + coarse = trace["log_rt_single_coarse"]["value"] + + assert log_rt.shape == (model.latent.n_initialization_points + 10, 1) + assert coarse.shape[0] < log_rt.shape[0] + assert jnp.allclose(log_rt[:3], log_rt[0]) + assert jnp.allclose(log_rt[3:10], log_rt[3]) + + def test_missing_obs_start_date_for_calendar_aligned_latent_process_raises(self): + """Calendar-aligned latent temporal processes trigger the model-entry anchor check.""" + latent = PopulationInfections( + name="PopulationInfections", + gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + n_initialization_points=3, + ) + model = MultiSignalModel(latent, {"ed": _daily_ed_counts()}) + + with numpyro.handlers.seed(rng_seed=42): + with pytest.raises( + ValueError, match="obs_start_date is required.*calendar_week" + ): + model.sample( + n_days_post_init=10, + population_size=1_000_000, + ed={"obs": None}, + ) + + def test_builder_mixed_cadence_weekly_rt_samples(self): + """ + Mixed daily/weekly observations can use weekly-parameterized Rt. + + The latent temporal process records a coarse Rt trajectory, the latent + process records daily Rt values, ED remains on the daily likelihood + scale, and hospital observations are scored on the weekly likelihood + scale. + """ + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + observations=[_weekly_hosp_counts(), _daily_ed_counts()], + ) + model = builder.build() + n_days_post_init = 28 + n_total = model.latent.n_initialization_points + n_days_post_init + obs_start_date = _obs_date_for_dow( + target_first_day_dow=6, + n_init=model.latent.n_initialization_points, + ) + hospital_obs = jnp.array([jnp.nan, 5.0, 7.0, 6.0], dtype=float) + ed_obs = jnp.concatenate( + [ + jnp.full(model.latent.n_initialization_points, jnp.nan), + jnp.ones(n_days_post_init) * 3.0, + ] + ) + + with numpyro.handlers.seed(rng_seed=42): + with numpyro.handlers.trace() as trace: + model.sample( + n_days_post_init=n_days_post_init, + population_size=1_000_000, + obs_start_date=obs_start_date, + hospital={"obs": hospital_obs}, + ed={"obs": ed_obs}, + ) + + assert trace["log_rt_single_coarse"]["value"].shape == (5, 1) + assert trace["PopulationInfections::log_rt_single"]["value"].shape == ( + n_total, + 1, + ) + assert trace["ed_obs"]["value"].shape == (n_total,) + assert trace["hospital_obs"]["value"].shape == hospital_obs.shape + assert trace["hospital_predicted_daily"]["value"].shape == (n_total,) + assert trace["hospital_predicted"]["value"].shape == hospital_obs.shape + class TestMultiSignalModelValidation: """Test data validation.""" @@ -284,7 +435,7 @@ def test_validate_data_accepts_valid_data(self, validation_builder): }, hospital_subpop={ "obs": jnp.array([10, 20]), - "times": jnp.array([5, 10]), + "period_end_times": jnp.array([5, 10]), }, ) @@ -299,7 +450,7 @@ def test_validate_data_rejects_out_of_bounds_times(self, validation_builder): subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "obs": jnp.array([10]), - "times": jnp.array([n_total + 10]), + "period_end_times": jnp.array([n_total + 10]), }, ) @@ -313,7 +464,7 @@ def test_validate_data_rejects_negative_times(self, validation_builder): subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "obs": jnp.array([10]), - "times": jnp.array([-1]), + "period_end_times": jnp.array([-1]), }, ) @@ -327,7 +478,7 @@ def test_validate_data_rejects_unknown_observation(self, validation_builder): subpop_fractions=SUBPOP_FRACTIONS, unknown_obs={ "obs": jnp.array([10]), - "times": jnp.array([5]), + "period_end_times": jnp.array([5]), }, ) @@ -343,7 +494,7 @@ def test_validate_data_rejects_mismatched_obs_times_length( subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "obs": jnp.array([10, 20, 30]), # 3 elements - "times": jnp.array([5, 10]), # 2 elements + "period_end_times": jnp.array([5, 10]), # 2 elements }, ) @@ -363,7 +514,7 @@ def test_validate_data_rejects_negative_subpop_indices(self, validation_builder) subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "subpop_indices": jnp.array([-1, 0, 1]), - "times": jnp.array([5, 6, 7]), + "period_end_times": jnp.array([5, 6, 7]), }, ) @@ -380,7 +531,7 @@ def test_validate_data_rejects_out_of_bounds_subpop_indices( subpop_fractions=SUBPOP_FRACTIONS, hospital_subpop={ "subpop_indices": jnp.array([0, 1, 5]), # 5 >= 3 - "times": jnp.array([5, 6, 7]), + "period_end_times": jnp.array([5, 6, 7]), }, ) @@ -426,10 +577,16 @@ def test_pad_observations_prepends_nans(self, simple_builder): (6, (6 - 3) % 7), ], ) - def test_compute_first_day_dow(self, simple_builder, obs_start_dow, expected): - """Test that compute_first_day_dow offsets by n_initialization_points.""" + def test_resolve_first_day_dow(self, simple_builder, obs_start_dow, expected): + """_resolve_first_day_dow offsets by n_initialization_points.""" + model = simple_builder.build() + obs_date = date(2024, 1, 1) + timedelta(days=obs_start_dow) + assert model._resolve_first_day_dow(obs_date) == expected + + def test_resolve_first_day_dow_none_passthrough(self, simple_builder): + """_resolve_first_day_dow returns None when obs_start_date is None.""" model = simple_builder.build() - assert model.compute_first_day_dow(obs_start_dow) == expected + assert model._resolve_first_day_dow(None) is None def test_shift_times_adds_offset(self, simple_builder): """Test that shift_times shifts by n_initialization_points.""" @@ -442,5 +599,273 @@ def test_shift_times_adds_offset(self, simple_builder): assert jnp.array_equal(shifted, times + n_init) +def _coherence_builder( + *, + single_rt_process, + observations, +): + """ + Build a configured PyrenewBuilder with a PopulationInfections latent and + a supplied temporal process, plus the given observation instances. + + Returns + ------- + PyrenewBuilder + Builder configured but not yet built. + """ + builder = PyrenewBuilder() + builder.configure_latent( + PopulationInfections, + gen_int_rv=DeterministicPMF("gen_int", jnp.array([0.2, 0.5, 0.3])), + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + single_rt_process=single_rt_process, + ) + for obs in observations: + builder.add_observation(obs) + return builder + + +def _weekly_hosp_counts(name="hospital", week=MMWR_WEEK): + """ + Build a weekly-aggregated PopulationCounts observation with PoissonNoise. + + Returns + ------- + PopulationCounts + Weekly-regular observation anchored to the specified ``WeekCycle``. + """ + return PopulationCounts( + name=name, + ascertainment_rate_rv=DeterministicVariable(f"{name}_ihr", 0.01), + delay_distribution_rv=DeterministicPMF(f"{name}_delay", jnp.array([1.0])), + noise=PoissonNoise(), + aggregation="weekly", + reporting_schedule="regular", + week=week, + ) + + +def _daily_ed_counts(name="ed"): + """ + Build a daily PopulationCounts observation with PoissonNoise. + + Returns + ------- + PopulationCounts + Daily-regular observation with no aggregation. + """ + return PopulationCounts( + name=name, + ascertainment_rate_rv=DeterministicVariable(f"{name}_ihr", 0.01), + delay_distribution_rv=DeterministicPMF(f"{name}_delay", jnp.array([1.0])), + noise=PoissonNoise(), + ) + + +class TestBuilderConfigurations: + """PyrenewBuilder.build() accepts varied R(t) and observation cadences.""" + + def test_daily_rt_with_daily_observation_passes(self): + """step_size=1 and P=1: valid.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[_daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_weekly_rt_with_weekly_observation_passes(self): + """step_size=7 and P=7: valid when weekly is the only obs cadence.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_daily_rt_with_mixed_observations_passes(self): + """step_size=1 with mixed P=1 + P=7: valid (R(t) at finest cadence).""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[_weekly_hosp_counts(), _daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_weekly_rt_with_daily_observation_passes(self): + """Coarse Rt parameter cadence is allowed with daily observations.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts(), _daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_mismatched_weekly_week_passes(self): + """Weekly observations can use different WeekCycles; each aggregates independently.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[ + _weekly_hosp_counts(name="hospital", week=MMWR_WEEK), + _weekly_hosp_counts(name="other", week=WeekCycle(start_dow=0)), + ], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_matching_weekly_week_passes(self): + """Two weekly observations sharing a WeekCycle build normally.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[ + _weekly_hosp_counts(name="hospital", week=MMWR_WEEK), + _weekly_hosp_counts(name="other", week=MMWR_WEEK), + ], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_arbitrary_step_size_with_weekly_observation_passes(self): + """Parameter cadence need not match observation aggregation period.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=2 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_calendar_week_alignment_matches_weekly_week_passes(self): + """Sunday-start weeks pair with Saturday-ending weekly observations.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + observations=[_weekly_hosp_counts(week=MMWR_WEEK)], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_calendar_week_alignment_mismatches_weekly_week_passes(self): + """A weekly Rt anchor can differ from a weekly observation anchor.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=WeekCycle(start_dow=0), + ), + observations=[_weekly_hosp_counts(week=MMWR_WEEK)], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_calendar_week_alignment_with_only_daily_observations_passes(self): + """Calendar-week-aligned R(t) pairs with daily observations.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + observations=[_daily_ed_counts()], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + def test_model_index_alignment_with_weekly_observation_passes(self): + """Model-index-aligned R(t) pairs with weekly observations.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts(week=MMWR_WEEK)], + ) + model = builder.build() + assert isinstance(model, MultiSignalModel) + + +class TestMultiSignalValidateDataAnchor: + """MultiSignalModel.validate_data sample-time anchor check for obs_start_date.""" + + def test_missing_obs_start_date_for_weekly_obs_raises(self): + """An observation with aggregation='weekly' must have obs_start_date supplied.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + with pytest.raises(ValueError, match="obs_start_date is required"): + model.validate_data( + n_days_post_init=28, + hospital={"obs": jnp.ones(4) * 5.0}, + ) + + def test_obs_start_date_supplied_for_weekly_obs_passes(self): + """Supplying obs_start_date satisfies the anchor check.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ), + observations=[_weekly_hosp_counts()], + ) + model = builder.build() + obs_start_date = _obs_date_for_dow( + target_first_day_dow=6, + n_init=model.latent.n_initialization_points, + ) + model.validate_data( + n_days_post_init=28, + obs_start_date=obs_start_date, + hospital={"obs": jnp.ones(4) * 5.0}, + ) + + def test_anchor_check_skipped_for_daily_obs(self): + """Daily observations do not require obs_start_date at validate_data time.""" + builder = _coherence_builder( + single_rt_process=AR1(autoreg=0.9, innovation_sd=0.05), + observations=[_daily_ed_counts()], + ) + model = builder.build() + n_total = model.latent.n_initialization_points + 30 + model.validate_data( + n_days_post_init=30, + ed={"obs": jnp.ones(n_total) * 5.0}, + ) + + def test_missing_obs_start_date_for_calendar_aligned_latent_raises(self): + """A calendar-week-aligned latent temporal process requires obs_start_date.""" + builder = _coherence_builder( + single_rt_process=StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ), + observations=[_daily_ed_counts()], + ) + model = builder.build() + n_total = model.latent.n_initialization_points + 30 + with pytest.raises( + ValueError, match="obs_start_date is required.*calendar_week" + ): + model.validate_data( + n_days_post_init=30, + ed={"obs": jnp.ones(n_total) * 5.0}, + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_subpopulation_infections.py b/test/test_subpopulation_infections.py index 658546e3..cc0adf58 100644 --- a/test/test_subpopulation_infections.py +++ b/test/test_subpopulation_infections.py @@ -84,6 +84,50 @@ def test_shape_and_positivity_across_subpop_counts( expected = jnp.sum(inf_all * fractions[jnp.newaxis, :], axis=1) assert jnp.allclose(inf_juris, expected, atol=1e-6) + def test_wrong_baseline_rt_shape_raises( + self, gen_int_rv, wrong_shape_temporal_process_cls + ): + """Baseline temporal process must return one daily baseline trajectory.""" + process = SubpopulationInfections( + name="subpopulation", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=wrong_shape_temporal_process_cls(jnp.zeros((16, 2))), + subpop_rt_deviation_process=RandomWalk(), + n_initialization_points=7, + ) + + with pytest.raises(ValueError, match="baseline_rt_process must return shape"): + process.sample( + n_days_post_init=10, + subpop_fractions=jnp.array([0.3, 0.25, 0.45]), + ) + + def test_wrong_subpop_deviation_shape_raises( + self, gen_int_rv, wrong_shape_temporal_process_cls, constant_temporal_process + ): + """Deviation temporal process must return one trajectory per subpop.""" + process = SubpopulationInfections( + name="subpopulation", + gen_int_rv=gen_int_rv, + I0_rv=DeterministicVariable("I0", 0.001), + log_rt_time_0_rv=DeterministicVariable("initial_log_rt", 0.0), + baseline_rt_process=constant_temporal_process, + subpop_rt_deviation_process=wrong_shape_temporal_process_cls( + jnp.zeros((17, 2)) + ), + n_initialization_points=7, + ) + + with pytest.raises( + ValueError, match="subpop_rt_deviation_process must return shape" + ): + process.sample( + n_days_post_init=10, + subpop_fractions=jnp.array([0.3, 0.25, 0.45]), + ) + class TestSubpopulationInfectionsValidation: """Test validation of inputs.""" diff --git a/test/test_temporal_processes.py b/test/test_temporal_processes.py index 7248518e..1ef3326f 100644 --- a/test/test_temporal_processes.py +++ b/test/test_temporal_processes.py @@ -6,7 +6,14 @@ import numpyro import pytest -from pyrenew.latent import AR1, DifferencedAR1, RandomWalk +from pyrenew.latent import AR1, DifferencedAR1, RandomWalk, StepwiseTemporalProcess +from pyrenew.time import MMWR_WEEK + +INNER_PROCESS_PARAMS = [ + (AR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (DifferencedAR1, {"autoreg": 0.9, "innovation_sd": 0.05}), + (RandomWalk, {"innovation_sd": 0.05}), +] class TestTemporalProcessVectorizedSampling: @@ -236,5 +243,205 @@ def test_differenced_ar1_trend_persistence(self): assert fraction_positive > 0.5 +class TestTemporalProcessStepSizeDefault: + """Standard temporal processes expose step_size=1 as a class attribute.""" + + @pytest.mark.parametrize( + "process_cls", + [AR1, DifferencedAR1, RandomWalk], + ) + def test_class_attribute_step_size_is_one(self, process_cls): + """Each concrete class has step_size=1 as a class attribute.""" + assert process_cls.step_size == 1 + + +class TestStepwiseTemporalProcessConstruction: + """Construction-time validation for StepwiseTemporalProcess.""" + + def test_step_size_attribute(self): + """step_size is exposed on the instance for builder inspection.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ) + assert wrapper.step_size == 7 + + def test_zero_step_size_raises(self): + """step_size=0 raises.""" + with pytest.raises(ValueError, match="positive integer"): + StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=0) + + def test_negative_step_size_raises(self): + """Negative step_size raises.""" + with pytest.raises(ValueError, match="positive integer"): + StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=-1) + + def test_float_step_size_raises(self): + """Non-integer step_size raises.""" + with pytest.raises(ValueError, match="positive integer"): + StepwiseTemporalProcess(AR1(autoreg=0.9, innovation_sd=0.05), step_size=7.0) + + def test_repr_includes_inner_and_step_size(self): + """__repr__ shows the inner process and step_size.""" + inner = AR1(autoreg=0.9, innovation_sd=0.05) + wrapper = StepwiseTemporalProcess(inner, step_size=7) + rendered = repr(wrapper) + assert rendered.startswith("StepwiseTemporalProcess(") + assert f"inner={inner!r}" in rendered + assert "step_size=7" in rendered + + def test_unknown_alignment_raises(self): + """Unknown alignment names raise.""" + with pytest.raises(ValueError, match="alignment"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="weekly", + ) + + def test_calendar_week_requires_weekly_step_size(self): + """calendar_week alignment is explicitly weekly.""" + with pytest.raises(ValueError, match="step_size=7"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=3, + alignment="calendar_week", + week=MMWR_WEEK, + ) + + def test_calendar_week_requires_week(self): + """calendar_week alignment requires a WeekCycle.""" + with pytest.raises(ValueError, match="week is required"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + ) + + def test_model_index_rejects_week(self): + """week is only meaningful for calendar_week alignment.""" + with pytest.raises(ValueError, match="week is only used"): + StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + week=MMWR_WEEK, + ) + + +class TestStepwiseTemporalProcessSample: + """Sample-time behavior of StepwiseTemporalProcess.""" + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_output_shape_divisible(self, inner_cls, inner_kwargs): + """n_timepoints divisible by step_size yields (n_timepoints, n_processes).""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=28, n_processes=3) + assert result.shape == (28, 3) + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_output_shape_non_divisible(self, inner_cls, inner_kwargs): + """n_timepoints not divisible by step_size still yields n_timepoints rows.""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=30, n_processes=2) + assert result.shape == (30, 2) + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_broadcast_within_block(self, inner_cls, inner_kwargs): + """Each block of step_size consecutive timepoints is constant.""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=28, n_processes=1) + for start in range(0, 28, 7): + block = result[start : start + 7] + assert jnp.allclose(block, block[0]) + + @pytest.mark.parametrize("inner_cls,inner_kwargs", INNER_PROCESS_PARAMS) + def test_broadcast_with_partial_final_block(self, inner_cls, inner_kwargs): + """A partial final block is still constant, just shorter than step_size.""" + wrapper = StepwiseTemporalProcess(inner_cls(**inner_kwargs), step_size=7) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=30, n_processes=1) + # The 4th block starts at row 28 and runs through row 29 (2 rows) + final_block = result[28:30] + assert jnp.allclose(final_block, final_block[0]) + + def test_step_size_one_passthrough_shape(self): + """step_size=1 yields (n_timepoints, n_processes), same as inner directly.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=1 + ) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample(n_timepoints=20, n_processes=2) + assert result.shape == (20, 2) + + def test_calendar_week_requires_first_day_dow_at_sample_time(self): + """calendar_week alignment needs the model-axis day of week.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ) + with numpyro.handlers.seed(rng_seed=42): + with pytest.raises(ValueError, match="first_day_dow"): + wrapper.sample(n_timepoints=20, n_processes=1) + + def test_calendar_week_alignment_with_leading_partial_week(self): + """calendar_week alignment starts full blocks on week_start_dow.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ) + # first_day_dow=3 means day 0 is Thursday. With Sunday week starts, + # days 0-2 are a leading partial week, then days 3-9 are the first + # full Sunday-Saturday block on the model axis. + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample( + n_timepoints=17, + n_processes=1, + first_day_dow=3, + ) + + assert jnp.allclose(result[:3], result[0]) + assert jnp.allclose(result[3:10], result[3]) + assert jnp.allclose(result[10:17], result[10]) + assert not jnp.allclose(result[2], result[3]) + + def test_calendar_week_alignment_without_leading_partial_week(self): + """calendar_week alignment handles model axes starting on week_start_dow.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), + step_size=7, + alignment="calendar_week", + week=MMWR_WEEK, + ) + with numpyro.handlers.seed(rng_seed=42): + result = wrapper.sample( + n_timepoints=15, + n_processes=1, + first_day_dow=6, + ) + + assert jnp.allclose(result[:7], result[0]) + assert jnp.allclose(result[7:14], result[7]) + assert jnp.allclose(result[14:15], result[14]) + + def test_coarse_trajectory_is_recorded(self): + """StepwiseTemporalProcess records the coarse trajectory.""" + wrapper = StepwiseTemporalProcess( + AR1(autoreg=0.9, innovation_sd=0.05), step_size=7 + ) + traced = numpyro.handlers.trace( + numpyro.handlers.seed(wrapper.sample, rng_seed=42) + ).get_trace(n_timepoints=15, n_processes=1, name_prefix="rt") + + assert "rt_coarse" in traced + assert traced["rt_coarse"]["type"] == "deterministic" + assert traced["rt_coarse"]["value"].shape == (3, 1) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/test/test_time.py b/test/test_time.py index a25bcd82..b6e86992 100644 --- a/test/test_time.py +++ b/test/test_time.py @@ -762,3 +762,52 @@ def test_create_date_time_spine_date_values(): assert dates[0] == dt.date(2025, 1, 1) assert dates[1] == dt.date(2025, 1, 2) assert dates[2] == dt.date(2025, 1, 3) + + +class TestWeekCycle: + """Tests for the :class:`pyrenew.time.WeekCycle` value object.""" + + def test_construction_and_end_dow(self): + """Construct with each valid ``start_dow`` and verify ``end_dow``.""" + for start in range(7): + cycle = ptime.WeekCycle(start_dow=start) + assert cycle.start_dow == start + assert cycle.end_dow == (start + 6) % 7 + + def test_mmwr_and_iso_constants(self): + """Named constants encode the MMWR and ISO weekly conventions.""" + assert ptime.MMWR_WEEK == ptime.WeekCycle(start_dow=6) + assert ptime.MMWR_WEEK.end_dow == 5 + assert ptime.ISO_WEEK == ptime.WeekCycle(start_dow=0) + assert ptime.ISO_WEEK.end_dow == 6 + + @pytest.mark.parametrize("bad", [-1, 7, 42]) + def test_invalid_start_dow_raises(self, bad): + """Out-of-range ``start_dow`` raises via :func:`validate_dow`.""" + with pytest.raises(ValueError, match="Day-of-week"): + ptime.WeekCycle(start_dow=bad) + + @pytest.mark.parametrize("bad", [None, "6", 6.0]) + def test_non_integer_start_dow_raises(self, bad): + """Non-integer ``start_dow`` raises via :func:`validate_dow`.""" + with pytest.raises(ValueError, match="must be integers"): + ptime.WeekCycle(start_dow=bad) + + def test_frozen_instance_rejects_mutation(self): + """``frozen=True`` blocks attribute assignment after construction.""" + cycle = ptime.WeekCycle(start_dow=3) + with pytest.raises(Exception): + cycle.start_dow = 4 + + def test_equality_and_hashability(self): + """Structural equality and set membership work as expected.""" + a = ptime.WeekCycle(start_dow=6) + b = ptime.WeekCycle(start_dow=6) + c = ptime.WeekCycle(start_dow=0) + assert a == b + assert a != c + assert len({a, b, c}) == 2 + + def test_not_equal_to_plain_tuple(self): + """Type-strict equality: a ``WeekCycle`` is not equal to ``(6,)``.""" + assert ptime.WeekCycle(start_dow=6) != (6,)