diff --git a/.github/ISSUE_TEMPLATE/development_task.md b/.github/ISSUE_TEMPLATE/development_task.md index 4c58ea0a..106826a9 100755 --- a/.github/ISSUE_TEMPLATE/development_task.md +++ b/.github/ISSUE_TEMPLATE/development_task.md @@ -1,33 +1,39 @@ --- name: Multisignal Renewal Dev Task Template -about: A template for writing specs following Multisignal Renewal project standards -title: '' +about: A template for writing specs following Multisignal Renewal project + standards +title: "" labels: ["development task"] -assignees: '' +assignees: "" --- -This template should be used as an outline. It may not be necessary to fill out every section. Delete this block of text and fill in anything in brackets. +This template should be used as an outline. +It may not be necessary to fill out every section. +Delete this block of text and fill in anything in brackets. Make sure you follow the project's standards specified in [this adr doc](https://github.com/cdcent/cfa-multisignal-renewal/blob/main/ADR/model/development_standards.md) (private link) ## Goal -[1-3 sentence summary of the issue or feature request. E.g. "We want to be able to ..."] +\[1-3 sentence summary of the issue or feature request. +E.g. "We want to be able to ..."\] ## Context -[Short paragraph describing how the issue arose and constraints imposed by the existing code architecture] +\[Short paragraph describing how the issue arose and constraints imposed by the existing code architecture\] ## Required features -- [Describe each thing you need the code to do to achieve the goal] -- [Example 1: Use a config to set input and output paths] -- [Example 2: Read in some-dataset and output some-transformed-dataset] +- \[Describe each thing you need the code to do to achieve the goal\] +- \[Example 1: Use a config to set input and output paths\] +- \[Example 2: Read in some-dataset and output some-transformed-dataset\] - etc... ## Specifications -[A checklist to keep track of details for each feature. At least one specification per feature is recommended. Edit the example below:] +\[A checklist to keep track of details for each feature. +At least one specification per feature is recommended. +Edit the example below:\] - [ ] EX2: A function that reads data from the `some-api` API and returns the dataset - [ ] EX2: Another function that inputs the dataset, performs $x$ transform, and outputs $y$ @@ -39,8 +45,8 @@ Make sure you follow the project's standards specified in [this adr doc](https:/ ## Out of scope -- [Things out of scope from this issue/PR] +- \[Things out of scope from this issue/PR\] ## Related documents -- [Link to related scripts, functions, issues, PRs, conversations, datasets, etc.] +- \[Link to related scripts, functions, issues, PRs, conversations, datasets, etc.\] diff --git a/.github/unused_templates/bug_report.md b/.github/unused_templates/bug_report.md deleted file mode 100755 index dd84ea78..00000000 --- a/.github/unused_templates/bug_report.md +++ /dev/null @@ -1,38 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: '' -labels: '' -assignees: '' - ---- - -**Describe the bug** -A clear and concise description of what the bug is. - -**To Reproduce** -Steps to reproduce the behavior: -1. Go to '...' -2. Click on '....' -3. Scroll down to '....' -4. See error - -**Expected behavior** -A clear and concise description of what you expected to happen. - -**Screenshots** -If applicable, add screenshots to help explain your problem. - -**Desktop (please complete the following information):** - - OS: [e.g. iOS] - - Browser [e.g. chrome, safari] - - Version [e.g. 22] - -**Smartphone (please complete the following information):** - - Device: [e.g. iPhone6] - - OS: [e.g. iOS8.1] - - Browser [e.g. stock browser, safari] - - Version [e.g. 22] - -**Additional context** -Add any other context about the problem here. diff --git a/.github/unused_templates/feature_request.md b/.github/unused_templates/feature_request.md deleted file mode 100755 index bbcbbe7d..00000000 --- a/.github/unused_templates/feature_request.md +++ /dev/null @@ -1,20 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '' -labels: '' -assignees: '' - ---- - -**Is your feature request related to a problem? Please describe.** -A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] - -**Describe the solution you'd like** -A clear and concise description of what you want to happen. - -**Describe alternatives you've considered** -A clear and concise description of any alternative solutions or features you've considered. - -**Additional context** -Add any other context or screenshots about the feature request here. diff --git a/.github/unused_templates/scientific-improvement.md b/.github/unused_templates/scientific-improvement.md deleted file mode 100755 index 4870c642..00000000 --- a/.github/unused_templates/scientific-improvement.md +++ /dev/null @@ -1,17 +0,0 @@ ---- -name: Scientific improvement -about: Suggest a way to improve an existing tool or pipeline -title: '' -labels: '' -assignees: '' - ---- - -## Describe the improvement that needs to be made -(e.g. update a parameter estimate, tweak the prior, modify the model) - -## Provide links to references to methods or data sources - -## Describe the changes expected to the model's outputs - -## Suggest new tests that will need to be implemented diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 581abf38..33c30d6f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: hooks: - id: check-added-large-files - id: check-yaml - args: [--unsafe] + args: [ --unsafe ] - id: check-toml - id: end-of-file-fixer - id: mixed-line-ending @@ -26,21 +26,19 @@ repos: - id: numpydoc-validation ##### # Quarto - - repo: local + - repo: https://github.com/jolars/panache + rev: v2.41.0 + hooks: - - id: format-qmd-python - name: Format Python in QMD - entry: python docs_scripts/quarto_python_formatter.py "-q --line-length 79" - language: python - files: \.qmd$ - additional_dependencies: [ruff] + - id: panache-format + - id: panache-lint ##### # Secrets - repo: https://github.com/Yelp/detect-secrets rev: v1.5.0 hooks: - id: detect-secrets - args: ["--baseline", ".secrets.baseline"] + args: [ "--baseline", ".secrets.baseline" ] exclude: package.lock.json #### # Typos @@ -48,4 +46,4 @@ repos: rev: v1 hooks: - id: typos - args: ["--force-exclude"] + args: [ "--force-exclude" ] diff --git a/README.md b/README.md index c3dc5db8..e760cef0 100755 --- a/README.md +++ b/README.md @@ -7,27 +7,29 @@ A renewal model estimates new infections from recent past infections using a gen From this, it infers $\mathcal{R}(t)$, the time-varying reproduction number, which indicates whether the number of infectious individuals is increasing or decreasing. The core renewal equation is: -$$I(t) = \mathcal{R}(t) \sum_{s} I(t-s) \, w(s)$$ +$$ +I(t) = \mathcal{R}(t) \sum_{s} I(t-s) \, w(s) +$$ where $w(s)$ is the generation interval distribution: the probability that $s$ time units separate infection in an index case and a secondary case. However inference is complicated by the fact that observational data require their own models ([Bhatt et al., 2023, §2](https://doi.org/10.1093/jrsssa/qnad030)). The observation equation links infections to expected observations: -$$\mu(t) = \alpha \sum_{s} I(t-s) \, \pi(s)$$ +$$ +\mu(t) = \alpha \sum_{s} I(t-s) \, \pi(s) +$$ where $\alpha$ is the ascertainment rate and $\pi(s)$ is the delay distribution from infection to observation. -The Pyrenew package provides configurable classes which encapsulate these components and methods to orchestrate the configuration and composition of these processes -resulting in programs which clearly express the model structure and choices, allowing for both ease of model specification and dissemination. +The Pyrenew package provides configurable classes which encapsulate these components and methods to orchestrate the configuration and composition of these processes resulting in programs which clearly express the model structure and choices, allowing for both ease of model specification and dissemination. The fundamental building block is the `RandomVariable` abstract base class, which allows for sampling from distributions, computing a mechanistic equation, or simply returning a fixed value. -We use `RandomVariable`s to build probabilistic models. We represent complete models as concrete subclasses of the `Model` abstract base class. +We use `RandomVariable`s to build probabilistic models. +We represent complete models as concrete subclasses of the `Model` abstract base class. The `PyrenewBuilder` class orchestrates the composition of `RandomVariables` into a `Model`. -PyRenew's strength lies in multi-signal integration for information pooling across diverse observed data streams -such as hospital admissions, wastewater concentrations, and emergency department visits -where each signal has distinct observation delays, noise characteristics, and spatial resolutions. +PyRenew's strength lies in multi-signal integration for information pooling across diverse observed data streams such as hospital admissions, wastewater concentrations, and emergency department visits where each signal has distinct observation delays, noise characteristics, and spatial resolutions. For single-signal renewal models, we recommend the excellent R package [EpiNow2](https://epiforecasts.io/EpiNow2/). ## Installation @@ -40,77 +42,63 @@ pip install git+https://github.com/CDCgov/PyRenew@main ## Models Implemented With PyRenew -- [CDCgov/pyrenew-covid-wastewater](https://github.com/CDCgov/pyrenew-covid-wastewater): _Models and infrastructure for forecasting COVID-19 hospitalizations using wastewater data with PyRenew._ -- [CDCgov/pyrenew-flu-light](https://github.com/CDCgov/pyrenew-flu-light/): _An instantiation in PyRenew of an influenza forecasting model used in the 2023-24 respiratory season._ +- [CDCgov/pyrenew-covid-wastewater](https://github.com/CDCgov/pyrenew-covid-wastewater): *Models and infrastructure for forecasting COVID-19 hospitalizations using wastewater data with PyRenew.* +- [CDCgov/pyrenew-flu-light](https://github.com/CDCgov/pyrenew-flu-light/): *An instantiation in PyRenew of an influenza forecasting model used in the 2023-24 respiratory season.* ## Resources -* [The PyRenew documentation suite](https://cdcgov.github.io/PyRenew) provides API reference documentation and tutorials on implementing multisignal renewal models with PyRenew. -* Additional reading on renewal processes in epidemiology - * [_Semi-mechanistic Bayesian modelling of COVID-19 with renewal processes_](https://academic.oup.com/jrsssa/article-pdf/186/4/601/54770289/qnad030.pdf) - * [_Unifying incidence and prevalence under a time-varying general branching process_](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf) +- [The PyRenew documentation suite](https://cdcgov.github.io/PyRenew) provides API reference documentation and tutorials on implementing multisignal renewal models with PyRenew. +- Additional reading on renewal processes in epidemiology + - [_Semi-mechanistic Bayesian modelling of COVID-19 with renewal processes_](https://academic.oup.com/jrsssa/article-pdf/186/4/601/54770289/qnad030.pdf) + - [_Unifying incidence and prevalence under a time-varying general branching process_](https://link.springer.com/content/pdf/10.1007/s00285-023-01958-w.pdf) ## General Disclaimer -This repository was created for use by CDC programs to collaborate on public health related projects in support of the [CDC mission](https://www.cdc.gov/about/cdc/index.html#cdc_about_cio_mission-our-mission). GitHub is not hosted by the CDC, but is a third party website used by CDC and its partners to share information and collaborate on software. CDC use of GitHub does not imply an endorsement of any one particular service, product, or enterprise. +This repository was created for use by CDC programs to collaborate on public health related projects in support of the [CDC mission](https://www.cdc.gov/about/cdc/index.html#cdc_about_cio_mission-our-mission). +GitHub is not hosted by the CDC, but is a third party website used by CDC and its partners to share information and collaborate on software. +CDC use of GitHub does not imply an endorsement of any one particular service, product, or enterprise. ## Public Domain Standard Notice -This repository constitutes a work of the United States Government and is not -subject to domestic copyright protection under 17 USC § 105. This repository is in -the public domain within the United States, and copyright and related rights in -the work worldwide are waived through the [CC0 1.0 Universal public domain dedication](https://creativecommons.org/publicdomain/zero/1.0/). -All contributions to this repository will be released under the CC0 dedication. By -submitting a pull request you are agreeing to comply with this waiver of -copyright interest. +This repository constitutes a work of the United States Government and is not subject to domestic copyright protection under 17 USC § 105. +This repository is in the public domain within the United States, and copyright and related rights in the work worldwide are waived through the [CC0 1.0 Universal public domain dedication](https://creativecommons.org/publicdomain/zero/1.0/). +All contributions to this repository will be released under the CC0 dedication. +By submitting a pull request you are agreeing to comply with this waiver of copyright interest. ## License Standard Notice This repository is licensed under ASL v2 or later. -This source code in this repository is free: you can redistribute it and/or modify it under -the terms of the Apache Software License version 2, or (at your option) any -later version. +This source code in this repository is free: you can redistribute it and/or modify it under the terms of the Apache Software License version 2, or (at your option) any later version. -This source code in this repository is distributed in the hope that it will be useful, but WITHOUT ANY -WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A -PARTICULAR PURPOSE. See the Apache Software License for more details. +This source code in this repository is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. +See the Apache Software License for more details. -You should have received a copy of the Apache Software License along with this -program. If not, see [http://www.apache.org/licenses/LICENSE-2.0.html](http://www.apache.org/licenses/LICENSE-2.0.html). +You should have received a copy of the Apache Software License along with this program. +If not, see [http://www.apache.org/licenses/LICENSE-2.0.html](http://www.apache.org/licenses/LICENSE-2.0.html). The source code forked from other open source projects will inherit its license. ## Privacy Standard Notice -This repository contains only non-sensitive, publicly available data and -information. All material and community participation is covered by the -[Disclaimer](https://github.com/CDCgov/template/blob/master/DISCLAIMER.md) -and [Code of Conduct](https://github.com/CDCgov/template/blob/master/code-of-conduct.md). +This repository contains only non-sensitive, publicly available data and information. +All material and community participation is covered by the [Disclaimer](https://github.com/CDCgov/template/blob/master/DISCLAIMER.md) and [Code of Conduct](https://github.com/CDCgov/template/blob/master/code-of-conduct.md). For more information about CDC's privacy policy, please visit [http://www.cdc.gov/other/privacy.html](https://www.cdc.gov/other/privacy.html). ## Contributing Standard Notice -Anyone is encouraged to contribute to the repository by [forking](https://help.github.com/articles/fork-a-repo) -and submitting a pull request. (If you are new to GitHub, you might start with a -[basic tutorial](https://help.github.com/articles/set-up-git).) By contributing -to this project, you grant a world-wide, royalty-free, perpetual, irrevocable, -non-exclusive, transferable license to all users under the terms of the -[Apache Software License v2](http://www.apache.org/licenses/LICENSE-2.0.html) or -later. +Anyone is encouraged to contribute to the repository by [forking](https://help.github.com/articles/fork-a-repo) and submitting a pull request. +(If you are new to GitHub, you might start with a [basic tutorial](https://help.github.com/articles/set-up-git).) +By contributing to this project, you grant a world-wide, royalty-free, perpetual, irrevocable, non-exclusive, transferable license to all users under the terms of the [Apache Software License v2](http://www.apache.org/licenses/LICENSE-2.0.html) or later. -All comments, messages, pull requests, and other submissions received through -CDC including this GitHub page may be subject to applicable federal law, including but not limited to the Federal Records Act, and may be archived. Learn more at [http://www.cdc.gov/other/privacy.html](http://www.cdc.gov/other/privacy.html). +All comments, messages, pull requests, and other submissions received through CDC including this GitHub page may be subject to applicable federal law, including but not limited to the Federal Records Act, and may be archived. +Learn more at [http://www.cdc.gov/other/privacy.html](http://www.cdc.gov/other/privacy.html). ## Records Management Standard Notice -This repository is not a source of government records but is a copy to increase -collaboration and collaborative potential. All government records will be -published through the [CDC web site](http://www.cdc.gov). +This repository is not a source of government records but is a copy to increase collaboration and collaborative potential. +All government records will be published through the [CDC web site](http://www.cdc.gov). ## Additional Standard Notices -Please refer to [CDC's Template Repository](https://github.com/CDCgov/template) -for more information about [contributing to this repository](https://github.com/CDCgov/template/blob/master/CONTRIBUTING.md), -[public domain notices and disclaimers](https://github.com/CDCgov/template/blob/master/DISCLAIMER.md), -and [code of conduct](https://github.com/CDCgov/template/blob/master/code-of-conduct.md). +Please refer to [CDC's Template Repository](https://github.com/CDCgov/template) for more information about [contributing to this repository](https://github.com/CDCgov/template/blob/master/CONTRIBUTING.md), [public domain notices and disclaimers](https://github.com/CDCgov/template/blob/master/DISCLAIMER.md), and [code of conduct](https://github.com/CDCgov/template/blob/master/code-of-conduct.md). diff --git a/docs/developer_documentation.md b/docs/developer_documentation.md index 9ee7ea94..380eafa1 100644 --- a/docs/developer_documentation.md +++ b/docs/developer_documentation.md @@ -1,14 +1,16 @@ # Developer Documentation -**Note: this document is a work in progress. Contrbitions to all sections are welcome.** +**Note: this document is a work in progress. +Contrbitions to all sections are welcome.** ## GitHub Workflow -- Reviews from all of `PyRenew-devs` are encouraged, but an approving review from a [codeowner](https://github.com/CDCgov/PyRenew/blob/main/.github/CODEOWNERS) ([@dylanhmorris](https://github.com/dylanhmorris) or [@damonbayer](https://github.com/damonbayer) is required before a pull request can be merged to `main`. +- Reviews from all of `PyRenew-devs` are encouraged, but an approving review from a [codeowner](https://github.com/CDCgov/PyRenew/blob/main/.github/CODEOWNERS) ([@dylanhmorris](https://github.com/dylanhmorris) or [@damonbayer](https://github.com/damonbayer) is required before a pull request can be merged to `main`. - For CDC contributors: if your pull request has not received a review at the time of the next standup, use standup to find a reviewer. - External contributors should expect to receive a review within a few days of creating a pull request. - If you create a draft pull request, indicate what, if anything, about the current pull request should be reviewed. -- Only mark a pull request as “ready for review” if you think it is ready to be merged. This indicates that a thorough, all-encompassing review should be given. +- Only mark a pull request as “ready for review” if you think it is ready to be merged. + This indicates that a thorough, all-encompassing review should be given. ## Installation for Developers @@ -22,30 +24,31 @@ A variety of coding conventions are enforced by automated tools in continuous in ## PyRenew Principles - Variable naming conventions - + Use the `data_` prefix for (potentially) observed data. - + Use the `_rv` suffix for random variables. - + Use the `observed_` for the output of sample statements where `obs` is a `data_` prefixed object. - + Thus, code which may reasonably written like `infections = infections.sample(x, obs=infections)` should instead be written `observed_infections = infections_rv.sample(x, obs=data_infections)`. - + - Use the `data_` prefix for (potentially) observed data. + - Use the `_rv` suffix for random variables. + - Use the `observed_` for the output of sample statements where `obs` is a `data_` prefixed object. + - Thus, code which may reasonably written like `infections = infections.sample(x, obs=infections)` should instead be written `observed_infections = infections_rv.sample(x, obs=data_infections)`. - Class conventions - + Composing models is discouraged. - + Returning anything from `Model.sample` is discouraged. Instead, sample from models using `Predictive` or our `prior_predictive` or `posterior_predictive` functions. - + Using `numpyro.deterministic` within a `RandomVariable` is discouraged. Only use at the `numpyro.deterministic` `Model` level. If something might need to be recorded from a `RandomVariable`, it should be returned from the `RandomVariable` so it can be recorded at the `Model` level. - + Using default site names in a `RandomVariable` is discouraged. Only use default site names at the `Model` level. - + Use `DeterministicVariable`s instead of constants within a model. - + - Composing models is discouraged. + - Returning anything from `Model.sample` is discouraged. + Instead, sample from models using `Predictive` or our `prior_predictive` or `posterior_predictive` functions. + - Using `numpyro.deterministic` within a `RandomVariable` is discouraged. + Only use at the `numpyro.deterministic` `Model` level. + If something might need to be recorded from a `RandomVariable`, it should be returned from the `RandomVariable` so it can be recorded at the `Model` level. + - Using default site names in a `RandomVariable` is discouraged. + Only use default site names at the `Model` level. + - Use `DeterministicVariable`s instead of constants within a model. - `scan` conventions - + Use `jax.lax.scan` for any scan whose iterations are deterministic, i.e. iterations contain no internal calls to `RandomVariable.sample()` or `numpyro.sample()`. - + Use `numpyro.scan` for any scan whose the iterations are stochastic, i.e. the iterations potentially include calls to `RandomVariable.sample()` or `numpyro.sample()`. - + - Use `jax.lax.scan` for any scan whose iterations are deterministic, i.e. iterations contain no internal calls to `RandomVariable.sample()` or `numpyro.sample()`. + - Use `numpyro.scan` for any scan whose the iterations are stochastic, i.e. the iterations potentially include calls to `RandomVariable.sample()` or `numpyro.sample()`. - Multidimensional array conventions - + In a multidimensional array of timeseries, time is always the first dimension. By default, `jax.lax.scan()` and `numpyro.contrib.control_flow.scan()` build output arrays by augmenting the first dimension, and variables are often scanned over time, making default output of scan over time sensible. + - In a multidimensional array of timeseries, time is always the first dimension. + By default, `jax.lax.scan()` and `numpyro.contrib.control_flow.scan()` build output arrays by augmenting the first dimension, and variables are often scanned over time, making default output of scan over time sensible. ## Documenting code for MkDocs The project uses [MkDocs](https://www.mkdocs.org/) and [mkdocstrings](https://mkdocstrings.github.io/) to generate documentation. -MkDocs builds the documentation pages from the source files contained in the `docs` directory -and these, in turn, contain `mkdocstrings` directives to include the docstrings in the source code file. +MkDocs builds the documentation pages from the source files contained in the `docs` directory and these, in turn, contain `mkdocstrings` directives to include the docstrings in the source code file. The top-level `Makefile` task `docs` will build the site locally in a new directory `site` @@ -73,11 +76,12 @@ The `make docs` Makefile task first renders the `.qmd` files to `.md`, then runs To make the new tutorial available in the website, developers should follow these steps: -1. Create a new `quarto` file in the `./docs/tutorials` directory. For instance, the `example_with_datasets.qmd` file was added to the repository. -2. Add an entry in the `./docs//tutorials/.pages` file to specify the order in which this tutorial will appear in the navigation sidebar. The entry specifies the *plain markdown* filename. +1. Create a new `quarto` file in the `./docs/tutorials` directory. + For instance, the `example_with_datasets.qmd` file was added to the repository. +2. Add an entry in the `./docs//tutorials/.pages` file to specify the order in which this tutorial will appear in the navigation sidebar. + The entry specifies the *plain markdown* filename. -For example, if you are adding a tutorial named `seasonal_effects.qmd`, then you would update the -file `docs/tutorials/.pages` as follows +For example, if you are adding a tutorial named `seasonal_effects.qmd`, then you would update the file `docs/tutorials/.pages` as follows ``` arrange: @@ -90,13 +94,12 @@ arrange: - seasonal_effects.md ``` - - ### Adding new pages To add a new page which is neither source code documentation nor a tutorial: -1. Create a `md` file in the appropriate directory. For example, this file about development was added under `./docs/source/developer_documentation.md`. +1. Create a `md` file in the appropriate directory. + For example, this file about development was added under `./docs/source/developer_documentation.md`. 2. Make sure the new `md` file is included in the `.pages` file for that directory. ``` diff --git a/docs/index.md b/docs/index.md index ed973294..db7d58f9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -18,7 +18,8 @@ It combines two distinct discrete convolutions which describe different processe New infections arise from past infections through a generation interval distribution. Let $I(t)$ denote the latent number of new infections at time $t$, and let $\mathcal{R}(t)$ denote the time-varying reproduction number. -Assume the generation interval distribution has finite support over lags $\tau = 1, \dots, G$. Let $w_\tau$ denote the probability that a secondary infection occurs $\tau$ days after infection in the primary case, with +Assume the generation interval distribution has finite support over lags $\tau = 1, \dots, G$. +Let $w_\tau$ denote the probability that a secondary infection occurs $\tau$ days after infection in the primary case, with $$ \sum_{\tau=1}^{G} w_\tau = 1, \qquad w_\tau \ge 0. @@ -38,7 +39,9 @@ In PyRenew, the latent process is represented on a **per-capita scale** (infecti Infections are latent and are not directly observed; instead, the data consist of events that occur some time after infection, such as hospitalizations or emergency department visits. -Let $\mu(t)$ denote the expected number of observed events at time $t$, and let $\alpha$ denote an **ascertainment rate**, the probability an infection is observed as an event. Assume the delay from infection to observation has finite support over lags $d = 0, \dots, D$. Let $\pi_d$ denote the probability that an infection is observed $d$ days later, with +Let $\mu(t)$ denote the expected number of observed events at time $t$, and let $\alpha$ denote an **ascertainment rate**, the probability an infection is observed as an event. +Assume the delay from infection to observation has finite support over lags $d = 0, \dots, D$. +Let $\pi_d$ denote the probability that an infection is observed $d$ days later, with $$ \sum_{d=0}^{D} \pi_d = 1, \qquad \pi_d \ge 0. @@ -56,14 +59,16 @@ Here, $d$ indexes lags in the infection-to-observation delay distribution. The observation equation defines the expected number of observed events at time $t$, but the actual observed data are stochastic. -Let $Y(t)$ denote the observed number of events at time $t$. We model observations as draws from a count distribution with central value (typically mean) $\mu(t)$: +Let $Y(t)$ denote the observed number of events at time $t$. +We model observations as draws from a count distribution with central value (typically mean) $\mu(t)$: $$ Y(t) \sim \text{Distribution}(\mu(t), \theta). $$ One possible choice is the Poisson distribution, which assumes the variance equals the mean. -In practice, epidemiological count data are often overdispersed relative to the Poisson. Negative binomial distributions are a common choice for modeling these overdispersed counts. +In practice, epidemiological count data are often overdispersed relative to the Poisson. +Negative binomial distributions are a common choice for modeling these overdispersed counts. The model thus has two layers: @@ -83,7 +88,6 @@ PyRenew's building blocks are: Components (generation interval, reproduction number process, observation process) are specified independently, so each can be swapped without changing the rest of the model. This makes it straightforward to move a quantity between "known" and "inferred" and keeps modeling choices explicit and reviewable. - ## Multi-signal models PyRenew's strength lies in multi-signal integration: pooling information across diverse observed data streams such as hospital admissions, wastewater concentrations, and emergency department visits, where each signal has distinct observation delays, noise characteristics, and spatial resolutions. diff --git a/docs/tutorials/building_multisignal_models.qmd b/docs/tutorials/building_multisignal_models.qmd index 565d481a..d0b4d526 100644 --- a/docs/tutorials/building_multisignal_models.qmd +++ b/docs/tutorials/building_multisignal_models.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -18,8 +18,9 @@ jupyter: --- ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import numpyro # to run samplers in parallel you must run `set_host_device_count` before importing jax @@ -28,7 +29,8 @@ numpyro.enable_x64() ``` ```{python} -# | label: imports-base +#| label: imports-base + import arviz as az import jax import jax.numpy as jnp @@ -53,7 +55,8 @@ def make_rng_key(): ``` ```{python} -# | label: imports-pyrenew +#| label: imports-pyrenew + from jax.typing import ArrayLike from pyrenew import datasets @@ -89,20 +92,23 @@ from pyrenew.time import MMWR_WEEK Renewal models in PyRenew combine two types of components: -1. **Latent infection process**: Generates unobserved infections via the renewal equation, driven by a time-varying reproduction number $\mathcal{R}(t)$ +1. **Latent infection process**: Generates unobserved infections via the renewal equation, driven by a time-varying reproduction number $\mathcal{R}(t)$ -2. **Observation processes**: Transform latent infections into observable signals (hospital admissions, wastewater concentrations, etc.) by applying delays, ascertainment, and noise +2. **Observation processes**: Transform latent infections into observable signals (hospital admissions, wastewater concentrations, etc.) by applying delays, ascertainment, and noise -A **multi-signal model** combines multiple observation processes—each representing a different data stream, e.g., hospital admissions, emergency deparatment visits, wastewater concentrations, which stem from the same underlying latent infection process. By jointly modeling these signals, we can improve estimation and prediction of the time-varying reproduction number $\mathcal{R}(t)$. Such a model must: +A **multi-signal model** combines multiple observation processes---each representing a different data stream, e.g., hospital admissions, emergency deparatment visits, wastewater concentrations, which stem from the same underlying latent infection process. +By jointly modeling these signals, we can improve estimation and prediction of the time-varying reproduction number $\mathcal{R}(t)$. +Such a model must: -- Generate a single coherent infection trajectory (or set of trajectories for subpopulations) -- Route those infections to each observation process appropriately -- Handle the initialization period required by delay distributions +- Generate a single coherent infection trajectory (or set of trajectories for subpopulations) +- Route those infections to each observation process appropriately +- Handle the initialization period required by delay distributions -The `PyrenewBuilder` class handles this plumbing. You specify: +The `PyrenewBuilder` class handles this plumbing. +You specify: -1. A single **latent process** (e.g., `SubpopulationInfections`) that defines how infections evolve. -2. One or more **observation processes** (e.g., `PopulationCounts`, `MeasurementObservation`) that define how infections become data. +1. A single **latent process** (e.g., `SubpopulationInfections`) that defines how infections evolve. +2. One or more **observation processes** (e.g., `PopulationCounts`, `MeasurementObservation`) that define how infections become data. The builder computes initialization requirements, wires components together, and produces a model ready for inference. @@ -110,9 +116,9 @@ The builder computes initialization requirements, wires components together, and Before diving into multi-signal models, you may want to review these foundational tutorials: -- **[Latent Infections](latent_infections.md)** and **[Latent Subpopulation Infections](latent_subpopulation_infections.md)**: Understanding temporal process choices for $\mathcal{R}(t)$ -- **[Observation Processes: Counts](observation_processes_counts.md)**: Modeling count data (admissions, deaths) -- **[Observation Processes: Measurements](observation_processes_measurements.md)**: Modeling continuous data (wastewater) +- **[Latent Infections](latent_infections.md)** and **[Latent Subpopulation Infections](latent_subpopulation_infections.md)**: Understanding temporal process choices for $\mathcal{R}(t)$ +- **[Observation Processes: Counts](observation_processes_counts.md)**: Modeling count data (admissions, deaths) +- **[Observation Processes: Measurements](observation_processes_measurements.md)**: Modeling continuous data (wastewater) This tutorial shows how to combine these components into a complete multi-signal model. @@ -120,18 +126,20 @@ This tutorial shows how to combine these components into a complete multi-signal This tutorial demonstrates building a multi-signal renewal model using: -- `SubpopulationInfections` — subpopulations share a jurisdiction-level baseline $\mathcal{R}(t)$ with subpopulation-specific deviations -- `PopulationCounts` — hospital admissions (jurisdiction-level) -- A custom `Wastewater` class — viral concentrations (subpopulation-level) +- `SubpopulationInfections` --- subpopulations share a jurisdiction-level baseline $\mathcal{R}(t)$ with subpopulation-specific deviations +- `PopulationCounts` --- hospital admissions (jurisdiction-level) +- A custom `Wastewater` class --- viral concentrations (subpopulation-level) ## Model Structure In this tutorial, we build a model that jointly fits two data streams to a shared latent infection process: -- **Hospital admissions** — jurisdiction-level counts that reflect *total* infections across all subpopulations, delayed and underascertained -- **Wastewater concentrations** — site-level measurements from a subset of subpopulations (catchment areas), reflecting viral shedding and dilution +- **Hospital admissions** --- jurisdiction-level counts that reflect *total* infections across all subpopulations, delayed and underascertained +- **Wastewater concentrations** --- site-level measurements from a subset of subpopulations (catchment areas), reflecting viral shedding and dilution -The diagram below shows how data flows through the model. The latent process generates infection trajectories for all subpopulations. Each observation process receives the infections it needs — aggregated totals or per-subpopulation arrays — and transforms them into predicted observations via delays, ascertainment, shedding kinetics, and noise. +The diagram below shows how data flows through the model. +The latent process generates infection trajectories for all subpopulations. +Each observation process receives the infections it needs --- aggregated totals or per-subpopulation arrays --- and transforms them into predicted observations via delays, ascertainment, shedding kinetics, and noise. ```mermaid flowchart TB @@ -165,33 +173,37 @@ flowchart TB ### Infection Resolution -Different observation processes observe different levels of the model hierarchy. Each observation process declares an **infection resolution** that determines what infection data it receives: +Different observation processes observe different levels of the model hierarchy. +Each observation process declares an **infection resolution** that determines what infection data it receives: -| Resolution | Receives | Example signals | -|----------------------|-------------------|-------------------------------| -| `"aggregate"` | Aggregated infections (sum across all subpopulations), shape `(T,)` | Hospital admissions, case counts | -| `"subpop"` | Infection matrix for all subpopulations, shape `(T, n_subpops)` | Wastewater, site-specific surveillance | + | Resolution | Receives | Example signals | + | ------------- | ------------------------------------------------------------------- | -------------------------------------- | + | `"aggregate"` | Aggregated infections (sum across all subpopulations), shape `(T,)` | Hospital admissions, case counts | + | `"subpop"` | Infection matrix for all subpopulations, shape `(T, n_subpops)` | Wastewater, site-specific surveillance | The `PyrenewBuilder` routes latent infections to observation processes based on each process's declared resolution. -For subpopulation-level observations, the observation process selects which subpopulations it observes using `subpop_indices` provided at sample/fit time. This allows flexible observation patterns—for example, wastewater samples might only cover 5 of 6 subpopulations (catchment areas), while the 6th represents areas without wastewater monitoring. +For subpopulation-level observations, the observation process selects which subpopulations it observes using `subpop_indices` provided at sample/fit time. +This allows flexible observation patterns---for example, wastewater samples might only cover 5 of 6 subpopulations (catchment areas), while the 6th represents areas without wastewater monitoring. With this structure in mind, we'll now define each component following the generative direction: first the latent infection process, then the observation processes. ## Latent Infection Process -Latent infection processes implement the renewal equation to generate infection trajectories. All latent processes share common components: +Latent infection processes implement the renewal equation to generate infection trajectories. +All latent processes share common components: -- **Generation interval**: PMF for secondary infection timing -- **Initial infections (I0)**: Starting condition for the renewal process -- **Temporal dynamics**: How $\mathcal{R}(t)$ evolves over time +- **Generation interval**: PMF for secondary infection timing +- **Initial infections (I0)**: Starting condition for the renewal process +- **Temporal dynamics**: How $\mathcal{R}(t)$ evolves over time ### Generation Interval The generation interval PMF specifies the probability that a secondary infection occurs $\tau$ days after the primary infection. ```{python} -# | label: gen-interval +#| label: gen-interval + covid_gen_int = [0.16, 0.32, 0.25, 0.14, 0.07, 0.04, 0.02] gen_int_pmf = jnp.array(covid_gen_int) gen_int_rv = DeterministicPMF("gen_int", gen_int_pmf) @@ -203,37 +215,38 @@ print(f"Generation interval length: {len(gen_int_pmf)} days") ### I0: Initial Infections -The initial infections RV `I0_rv` specifies the **proportion of the population infected** at the first observation time. This must be a value in the interval (0, 1]. We use a Beta prior centered near a small value: +The initial infections RV `I0_rv` specifies the **proportion of the population infected** at the first observation time. +This must be a value in the interval (0, 1\]. +We use a Beta prior centered near a small value: ```{python} -# | label: initial-infections +#| label: initial-infections + I0_rv = DistributionalVariable("I0", dist.Beta(1, 100)) ``` - - ### Log Rt at time $0$ We place a prior on the log(Rt) at time $0$, centered at 0.0 (Rt = 1.0) with moderate uncertainty: ```{python} -# | label: log-rt-time-0 -log_rt_time_0_rv = DistributionalVariable( - "log_rt_time_0", dist.Normal(0.0, 0.5) -) +#| label: log-rt-time-0 + +log_rt_time_0_rv = DistributionalVariable("log_rt_time_0", dist.Normal(0.0, 0.5)) ``` ### Temporal Processes for $\mathcal{R}(t)$ We configure two temporal processes: -- **Jurisdiction-level** (`baseline_rt_process`): AR(1) process for the baseline $\mathcal{R}(t)$ -- **Subpopulation-level** (`subpop_rt_deviation_process`): RandomWalk for subpopulation deviations +- **Jurisdiction-level** (`baseline_rt_process`): AR(1) process for the baseline $\mathcal{R}(t)$ +- **Subpopulation-level** (`subpop_rt_deviation_process`): RandomWalk for subpopulation deviations The RandomWalk allows flexible evolution of subpopulation-specific transmission without mean reversion. ```{python} -# | label: temporal-processes +#| label: temporal-processes + # AR1 provides mean-reverting behavior for baseline Rt baseline_rt_process = AR1(autoreg=0.9, innovation_sd=0.05) @@ -246,9 +259,9 @@ subpop_rt_deviation_process = RandomWalk(innovation_sd=0.025) The renewal equation is evaluated on the model's daily time axis, but the temporal process for $\mathcal{R}(t)$ does not have to sample a new parameter every day. This separates three model choices: -- **Parameter cadence**: how often the $\mathcal{R}(t)$ temporal process samples a new latent value -- **Model time axis**: the daily axis used by the renewal equation and delay convolutions -- **Observation cadence**: the temporal granularity for each signal, such as daily ED visits or weekly hospital admissions +- **Parameter cadence**: how often the $\mathcal{R}(t)$ temporal process samples a new latent value +- **Model time axis**: the daily axis used by the renewal equation and delay convolutions +- **Observation cadence**: the temporal granularity for each signal, such as daily ED visits or weekly hospital admissions The AR(1) process above samples one value per model day: @@ -276,33 +289,35 @@ The model handles the calendar bookkeeping and forwards the day-of-week informat ## Observation Processes -Observation processes transform latent infections into observable signals and define the statistical model linking predictions to data. Each observation process: +Observation processes transform latent infections into observable signals and define the statistical model linking predictions to data. +Each observation process: -- Has a unique **name** that identifies the signal in model outputs -- Declares what **infection resolution** it needs (`"aggregate"` or `"subpop"`) -- Applies signal-specific transformations (ascertainment, delay convolution, shedding kinetics) -- Defines the noise model +- Has a unique **name** that identifies the signal in model outputs +- Declares what **infection resolution** it needs (`"aggregate"` or `"subpop"`) +- Applies signal-specific transformations (ascertainment, delay convolution, shedding kinetics) +- Defines the noise model ### Signal Naming -Each observation process requires a `name` parameter—a short, meaningful identifier like `"hospital"` or `"wastewater"`. This name serves as the single identifier for the signal throughout the model: +Each observation process requires a `name` parameter---a short, meaningful identifier like `"hospital"` or `"wastewater"`. +This name serves as the single identifier for the signal throughout the model: -- **Numpyro sites**: Prefixes all sample and deterministic sites (e.g., `hospital_obs`, `hospital_predicted`) -- **Data binding**: Becomes the keyword argument for passing data to `model.run()` (e.g., `hospital={...}`) +- **Numpyro sites**: Prefixes all sample and deterministic sites (e.g., `hospital_obs`, `hospital_predicted`) +- **Data binding**: Becomes the keyword argument for passing data to `model.run()` (e.g., `hospital={...}`) This unified naming provides several benefits: -- **Interpretable outputs**: When examining MCMC samples or posterior diagnostics, site names like `hospital_predicted` immediately indicate which signal each quantity refers to -- **Multiple signals of the same type**: You can include multiple count observations (e.g., hospital admissions and deaths) by giving each a distinct name -- **Clearer debugging**: Error messages and trace inspection show meaningful signal names rather than generic identifiers +- **Interpretable outputs**: When examining MCMC samples or posterior diagnostics, site names like `hospital_predicted` immediately indicate which signal each quantity refers to +- **Multiple signals of the same type**: You can include multiple count observations (e.g., hospital admissions and deaths) by giving each a distinct name +- **Clearer debugging**: Error messages and trace inspection show meaningful signal names rather than generic identifiers ### Hospital Admissions -In this example we use a dataset consisting of hospital admissions for COVID-19 across California -for the first 10 months of 2023 (as reported to the CDC). +In this example we use a dataset consisting of hospital admissions for COVID-19 across California for the first 10 months of 2023 (as reported to the CDC). ```{python} -# | label: load-hospital-data +#| label: load-hospital-data + # Load daily hospital admissions for California ca_hosp_data = datasets.load_hospital_data_for_state("CA", "2023-11-06.csv") obs_start_date = ca_hosp_data["dates"][0] @@ -314,15 +329,14 @@ print("State: California") print(f"Population: {population_size:,}") print(f"Date range: {ca_hosp_data['dates'][0]} to {ca_hosp_data['dates'][-1]}") print(f"Number of days: {n_hosp_days}") -print( - f"Admissions range: {int(hosp_admits.min())} to {int(hosp_admits.max())}" -) +print(f"Admissions range: {int(hosp_admits.min())} to {int(hosp_admits.max())}") ``` The hospital admissions data is aggregated at the jurisdiction level, therefore we specify a `PopulationCounts` observation process. ```{python} -# | label: hospital-obs-process +#| label: hospital-obs-process + # Infection-to-hospitalization delay (COVID-19, from literature) inf_to_hosp_pmf = jnp.array( [ @@ -396,12 +410,13 @@ print(f" Delay PMF length: {len(inf_to_hosp_pmf)} days") #### Wastewater Observation Process -The `MeasurementObservation` base class handles continuous observation processes. Domain-specific -implementations subclass it and implement `_predicted_obs()` to transform infections -into predicted values. See `observation_processes_measurements.md` for a detailed tutorial. +The `MeasurementObservation` base class handles continuous observation processes. +Domain-specific implementations subclass it and implement `_predicted_obs()` to transform infections into predicted values. +See `observation_processes_measurements.md` for a detailed tutorial. ```{python} -# | label: wastewater-class +#| label: wastewater-class + class Wastewater(MeasurementObservation): """ Wastewater viral concentration observation process. @@ -418,9 +433,7 @@ class Wastewater(MeasurementObservation): ml_per_person_per_day: float, noise: MeasurementNoise, ) -> None: - super().__init__( - name=name, temporal_pmf_rv=shedding_kinetics_rv, noise=noise - ) + super().__init__(name=name, temporal_pmf_rv=shedding_kinetics_rv, noise=noise) self.log10_genome_per_infection_rv = log10_genome_per_infection_rv self.ml_per_person_per_day = ml_per_person_per_day @@ -442,23 +455,19 @@ class Wastewater(MeasurementObservation): ) return convolved - shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)( - infections - ) + shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)(infections) genome_copies = 10**log10_genome - concentration = ( - shedding_signal * genome_copies / self.ml_per_person_per_day - ) + concentration = shedding_signal * genome_copies / self.ml_per_person_per_day return jnp.log(concentration) ``` #### Wastewater Data -For the wastewater data, we use a simulated dataset for California with realistic noise patterns -that covers the same time period. +For the wastewater data, we use a simulated dataset for California with realistic noise patterns that covers the same time period. ```{python} -# | label: load-wastewater-data +#| label: load-wastewater-data + # Load wastewater data for California ca_ww_data = datasets.load_wastewater_data_for_state("CA", "fake_nwss.csv") @@ -473,9 +482,7 @@ print("State: California") print(f"Number of sites: {ww_n_sites}") print(f"Number of observations: {ww_n_obs}") print(f"Date range: {ca_ww_data['dates'][0]} to {ca_ww_data['dates'][-1]}") -print( - f"Time index range: {int(ww_time_indices.min())} to {int(ww_time_indices.max())}" -) +print(f"Time index range: {int(ww_time_indices.min())} to {int(ww_time_indices.max())}") print("\nSites:") for i, name in enumerate(ww_wwtp_names[:5]): print(f" {i}: {name}") @@ -483,12 +490,15 @@ if ww_n_sites > 5: print(f" ... and {ww_n_sites - 5} more") ``` -Wastewater observations are site-level: each measurement is associated with a specific measurement site. The Wastewater observation process uses LogNormalNoise, which takes hierarchical priors for the site-level mode and standard deviation parameters. This enables partial pooling across measurement sites. +Wastewater observations are site-level: each measurement is associated with a specific measurement site. +The Wastewater observation process uses LogNormalNoise, which takes hierarchical priors for the site-level mode and standard deviation parameters. +This enables partial pooling across measurement sites. Here we specify HierarchicalNormalPrior for the site-level mode and GammaGroupSdPrior for the standard deviation. ```{python} -# | label: wastewater-obs-process +#| label: wastewater-obs-process + # Viral shedding kinetics PMF (days post-infection) shedding_pmf = jnp.array( [ @@ -544,25 +554,28 @@ print(f" Shedding PMF length: {len(shedding_pmf)} days") We instantiate a `PyrenewBuilder` object which handles the composition of the latent infection process and the observation process. ```{python} -# | label: model-builder-init +#| label: model-builder-init + # Build the multi-signal model builder = PyrenewBuilder() ``` The `PyrenewBuilder` object has 3 key methods: -- `configure_latent` -- `add_observation` -- `build` +- `configure_latent` +- `add_observation` +- `build` -Methods `configure_latent` and `add_observation` can be called in any order. Method `build` is called once all processes have been specified in the model. +Methods `configure_latent` and `add_observation` can be called in any order. +Method `build` is called once all processes have been specified in the model. ### Configuring the Latent Process We use `configure_latent` to specify the **model structure**: generation interval, initial infections, and temporal dynamics. ```{python} -# | label: configure-latent +#| label: configure-latent + print("Latent process configuration:") print(f" Generation interval length: {len(gen_int_rv())} days") @@ -581,7 +594,8 @@ builder.configure_latent( Each observation process's `name` attribute becomes the keyword used to pass that observation's data to `model.run()` (e.g., `hospital={...}`, `wastewater={...}`). ```{python} -# | label: add-observations-build +#| label: add-observations-build + builder.add_observation(hosp_obs) # Uses hosp_obs.name = "hospital" builder.add_observation(ww_obs) # Uses ww_obs.name = "wastewater" model = builder.build() @@ -596,22 +610,22 @@ print(f" Observation processes: {list(model.observations.keys())}") ### Model Identifiability. The renewal equation is linear in the initial infections `I0_rv`, so scaling `I0_rv` by a factor $c$ scales the entire infection trajectory by $c$. -In practice, `I0_rv` is weakly identified because each observation process links infections to data through a signal-specific ascertainment rate $\alpha_s$ — the probability that an infection is observed as an event in signal $s$. +In practice, `I0_rv` is weakly identified because each observation process links infections to data through a signal-specific ascertainment rate $\alpha_s$ --- the probability that an infection is observed as an event in signal $s$. Doubling `I0_rv` while halving all ascertainment rates produces identical expected observations. Without external information to anchor either the ascertainment rates or the absolute infection level, the data cannot distinguish "more infections, lower ascertainment" from "fewer infections, higher ascertainment." The priors on `I0_rv` and on the ascertainment rates resolve this ambiguity. - ## Fitting the Model to Data: `model.run()` When you call `model.run()`, you supply two types of information: -- **Observation data** — one data dictionary per registered observation process -- **Population structure** — how the jurisdiction is divided into subpopulations +- **Observation data** --- one data dictionary per registered observation process +- **Population structure** --- how the jurisdiction is divided into subpopulations ### Shared Time Axis -All observation data uses a **shared time axis** `[0, n_total)` where `n_total = n_init + n_days`. This shared axis aligns observations with the internal infection vectors: +All observation data uses a **shared time axis** `[0, n_total)` where `n_total = n_init + n_days`. +This shared axis aligns observations with the internal infection vectors: - Index 0 corresponds to the first day of the initialization period - Index `n_init` corresponds to the first day of actual observations @@ -619,16 +633,16 @@ All observation data uses a **shared time axis** `[0, n_total)` where `n_total = The model provides helper methods to align your data with this shared axis: -- `model.pad_observations(obs)` — prepends `n_init` NaN values to dense observation vectors -- `model.shift_times(times)` — adds `n_init` to sparse time indices +- `model.pad_observations(obs)` --- prepends `n_init` NaN values to dense observation vectors +- `model.shift_times(times)` --- adds `n_init` to sparse time indices ### Observation Data by Signal Type Each observation process's `name` attribute becomes the keyword argument for passing data to `model.run()`: ```python -builder.add_observation(hosp_obs) # hosp_obs.name="hospital" → hospital={...} -builder.add_observation(ww_obs) # ww_obs.name="wastewater" → wastewater={...} +builder.add_observation(hosp_obs) # hosp_obs.name="hospital" → hospital={...} +builder.add_observation(ww_obs) # ww_obs.name="wastewater" → wastewater={...} ``` #### Jurisdiction-level signals (dense) @@ -636,37 +650,51 @@ builder.add_observation(ww_obs) # ww_obs.name="wastewater" → wastewater={. The jurisdiction-level hospital admissions data is specified as a `PopulationCounts` observations process with dense data padded to length `n_total`: ```python -hospital={ +hospital = { "obs": model.pad_observations(hosp_counts), # shape: (n_total,), NaN-padded } ``` -The `pad_observations` method prepends `n_init` NaN values. NaN marks the initialization period where predictions exist but observations do not. You can also use NaN to mark missing data within the observation period. +The `pad_observations` method prepends `n_init` NaN values. +NaN marks the initialization period where predictions exist but observations do not. +You can also use NaN to mark missing data within the observation period. #### Subpopulation-level signals (sparse) The subpopulation-level wastewater data is specified as a `Wastewater` observations process with sparse indexing on the shared time axis: ```python -wastewater={ - "obs": jnp.array([...]), # observed log concentrations (n_obs,) - "times": model.shift_times(ww_times), # time indices on shared axis - "subpop_indices": jnp.array([...]), # which subpopulation (selects infection column) - "sensor_indices": jnp.array([...]), # which WWTP/lab pair (selects noise parameters) - "n_sensors": int, # total number of WWTP/lab pairs +wastewater = { + "obs": jnp.array([...]), # observed log concentrations (n_obs,) + "times": model.shift_times(ww_times), # time indices on shared axis + "subpop_indices": jnp.array( + [...] + ), # which subpopulation (selects infection column) + "sensor_indices": jnp.array( + [...] + ), # which WWTP/lab pair (selects noise parameters) + "n_sensors": int, # total number of WWTP/lab pairs } ``` The `shift_times` method adds `n_init` to convert from natural coordinates (0 = first observation day) to the shared time axis. -**Understanding `subpop_indices`**: The latent process generates infections for all subpopulations as a matrix of shape `(T, n_subpops)`. Each observation selects which column (subpopulation) it came from using `subpop_indices`. This is how observation processes "know" which subpopulations they observe—the user specifies this mapping at sample/run time. +**Understanding `subpop_indices`**: The latent process generates infections for all subpopulations as a matrix of shape `(T, n_subpops)`. +Each observation selects which column (subpopulation) it came from using `subpop_indices`. +This is how observation processes "know" which subpopulations they observe---the user specifies this mapping at sample/run time. -A **subpopulation** is a portion of the jurisdiction's population (e.g., a catchment area). A **sensor** is a measurement source — typically a WWTP/lab pair — that produces observations. Multiple sensors can observe the same subpopulation (e.g., different labs processing samples from the same catchment), so `subpop_indices` and `sensor_indices` may differ. +A **subpopulation** is a portion of the jurisdiction's population (e.g., a catchment area). +A **sensor** is a measurement source --- typically a WWTP/lab pair --- that produces observations. +Multiple sensors can observe the same subpopulation (e.g., different labs processing samples from the same catchment), so `subpop_indices` and `sensor_indices` may differ. -- `subpop_indices` links each observation to the appropriate infection column (0-indexed into the subpopulations) -- `sensor_indices` selects which sensor's noise parameters (mode and sd) to apply +- `subpop_indices` links each observation to the appropriate infection column (0-indexed into the subpopulations) +- `sensor_indices` selects which sensor's noise parameters (mode and sd) to apply -**Example**: A jurisdiction has 6 subpopulations (indices 0-5), where 5 have wastewater monitoring and 1 does not. The `subpop_fractions` array has 6 elements. If subpopulation 2 lacks wastewater monitoring, wastewater observations would have `subpop_indices` values only in {0, 1, 3, 4, 5}—never 2. The monitored subpopulations need not be contiguous. The latent process still generates infections for all 6 subpopulations; the wastewater observation just doesn't see subpopulation 2. +**Example**: A jurisdiction has 6 subpopulations (indices 0-5), where 5 have wastewater monitoring and 1 does not. +The `subpop_fractions` array has 6 elements. +If subpopulation 2 lacks wastewater monitoring, wastewater observations would have `subpop_indices` values only in {0, 1, 3, 4, 5}---never 2. +The monitored subpopulations need not be contiguous. +The latent process still generates infections for all 6 subpopulations; the wastewater observation just doesn't see subpopulation 2. ### Population Structure @@ -679,10 +707,13 @@ model.run( ) ``` -This specifies 6 subpopulations with their population fractions. The fractions must sum to 1.0. The latent process generates infections for all 6 subpopulations. - -Which subpopulations each observation process "sees" is determined by the `subpop_indices` in the observation data, not by the population structure. For example, if wastewater monitoring covers only 5 of the 6 subpopulations (say, all except subpopulation 2), the wastewater observation data would have `subpop_indices` values in {0, 1, 3, 4, 5} but never 2. The monitored subpopulations can be any subset of {0, ..., n_subpops-1}. +This specifies 6 subpopulations with their population fractions. +The fractions must sum to 1.0. +The latent process generates infections for all 6 subpopulations. +Which subpopulations each observation process "sees" is determined by the `subpop_indices` in the observation data, not by the population structure. +For example, if wastewater monitoring covers only 5 of the 6 subpopulations (say, all except subpopulation 2), the wastewater observation data would have `subpop_indices` values in {0, 1, 3, 4, 5} but never 2. +The monitored subpopulations can be any subset of {0, ..., n_subpops-1}. ### Example `model.run()` Call @@ -700,8 +731,8 @@ model.run( samples = model.mcmc.get_samples() ``` -where `**obs_data` is a data dictionary which supplies the observation start date as `obs_start_date` and a data dictionary for each set of signal data, where the name of the data dictionary corresponds to the signal name registered on the builder. -For this example, such a data dictionary would have the following structure: +where `**obs_data` is a data dictionary which supplies the observation start date as `obs_start_date` and a data dictionary for each set of signal data, where the name of the data dictionary corresponds to the signal name registered on the builder. +For this example, such a data dictionary would have the following structure: ```python { @@ -719,13 +750,15 @@ For this example, such a data dictionary would have the following structure: } ``` - ## Running the Model -First we declare the population structure. We have 6 subpopulations, where 5 have wastewater monitoring and 1 does not. The subpopulations with wastewater monitoring need not be contiguous indices—they could be any subset of {0, 1, ..., n_subpops-1}. +First we declare the population structure. +We have 6 subpopulations, where 5 have wastewater monitoring and 1 does not. +The subpopulations with wastewater monitoring need not be contiguous indices---they could be any subset of {0, 1, ..., n_subpops-1}. ```{python} -# | label: population-structure +#| label: population-structure + # All 6 subpopulations with their population fractions subpop_fractions = jnp.array([0.10, 0.14, 0.21, 0.22, 0.07, 0.26]) @@ -739,9 +772,7 @@ ww_monitored_subpops = jnp.array( ) # subpop 5 has no wastewater monitoring print(f"Total subpopulations: {n_subpops}") -print( - f"Subpopulations with wastewater monitoring: {list(ww_monitored_subpops)}" -) +print(f"Subpopulations with wastewater monitoring: {list(ww_monitored_subpops)}") print( f"Wastewater coverage: {float(jnp.sum(subpop_fractions[ww_monitored_subpops])):.0%}" ) @@ -753,7 +784,8 @@ The returned dictionary is structured to match the keyword arguments of `model.r At call time the returned dict is unpacked with `**` (for example, `**obs_data_90` in the 90-day fit below), forwarding each entry as a keyword argument to `model.run()`. ```{python} -# | label: prepare-observation-data +#| label: prepare-observation-data + def prepare_observation_data( model, n_days_fit, @@ -800,12 +832,9 @@ def prepare_observation_data( n_ww_sensors = ww_data["n_sites"] n_monitored = len(ww_monitored_subpops) sensor_to_subpop = { - i: int(ww_monitored_subpops[i % n_monitored]) - for i in range(n_ww_sensors) + i: int(ww_monitored_subpops[i % n_monitored]) for i in range(n_ww_sensors) } - ww_subpop_indices = jnp.array( - [sensor_to_subpop[int(s)] for s in ww_sensors] - ) + ww_subpop_indices = jnp.array([sensor_to_subpop[int(s)] for s in ww_sensors]) return { "obs_start_date": obs_start_date, @@ -828,7 +857,8 @@ Putting this altogether, we align the data with the model time and call `model.r We run 4 sampler chains. ```{python} -# | label: fit-90-days +#| label: fit-90-days + # Clear JAX caches to avoid interference from earlier cells jax.clear_caches() @@ -866,11 +896,11 @@ print(f"Elapsed time: {elapsed_90:.1f} seconds") We use [ArviZ](https://python.arviz.org/en/stable/) to assess MCMC convergence and mixing via the $\hat{R}$ statistic and effective sample size (ESS). Before running these diagnostics, it is necessary to we trim the first `n_init` time steps from all time-series variables. -Since the model cannot estimate latent infections until it has seen a full generation interval's worth of data, -these early time steps have no meaningful epidemiological interpretation and therefore should be excluded from summaries and visualizations. +Since the model cannot estimate latent infections until it has seen a full generation interval's worth of data, these early time steps have no meaningful epidemiological interpretation and therefore should be excluded from summaries and visualizations. ```{python} -# | label: helper-trim-time +#| label: helper-trim-time + def trim_time(ds): """Trim first n_init entries from time dimension and reindex.""" if "time" in ds.dims: @@ -884,7 +914,8 @@ Then we trim the first `n_init` time steps from all time-series variables. Finally we call the [az.summary](https://python.arviz.org/projects/stats/en/latest/api/generated/arviz_stats.summary.html) report. ```{python} -# | label: arviz-diagnostics-90 +#| label: arviz-diagnostics-90 + idata_90 = az.from_numpyro( model.mcmc, dims={ @@ -901,15 +932,13 @@ idata_90 = az.from_numpyro( ) idata_90_trimmed = idata_90.map_over_datasets(trim_time) -az.summary( - idata_90_trimmed, var_names=["latent_infections", "hospital_predicted"] -) +az.summary(idata_90_trimmed, var_names=["latent_infections", "hospital_predicted"]) ``` We extract the posterior quantiles and print summary statistics. ```{python} -# | label: extract-quantiles-90 +#| label: extract-quantiles-90 latent_inf = idata_90_trimmed.posterior["latent_infections"] @@ -925,11 +954,14 @@ print(f" Mean 90% CI width: {ci_width_90.mean():,.0f} infections") print(f" Median infections (day 45): {quantiles_90['q50'][45]:,.0f}") ``` -Finally, we visualize the posterior latent infections alongside observed hospitalizations. Note that **hospital admissions lag behind infections** by the infection-to-hospitalization delay (mode ~10 days in our delay PMF). When comparing the two panels, peaks in the infection curve should precede corresponding peaks in hospitalizations by roughly 10-14 days. +Finally, we visualize the posterior latent infections alongside observed hospitalizations. +Note that **hospital admissions lag behind infections** by the infection-to-hospitalization delay (mode \~10 days in our delay PMF). +When comparing the two panels, peaks in the infection curve should precede corresponding peaks in hospitalizations by roughly 10-14 days. ```{python} -# | label: fig-posterior-90 -# | fig-cap: Posterior latent infections and observed hospitalizations (90 days). +#| label: fig-posterior-90 +#| fig-cap: Posterior latent infections and observed hospitalizations (90 days). + # Visualize posterior latent infections and observed hospitalizations (90 days) # Create separate dataframes for faceted plot infections_df_90 = pd.DataFrame( @@ -944,9 +976,7 @@ infections_df_90 = pd.DataFrame( # Add 14-day moving average to smooth noisy daily admissions hosp_raw_90 = np.array(hosp_admits[:n_days_90], dtype=float) -hosp_ma_90 = ( - pd.Series(hosp_raw_90).rolling(window=14, center=True).mean().values -) +hosp_ma_90 = pd.Series(hosp_raw_90).rolling(window=14, center=True).mean().values hosp_df_90 = pd.DataFrame( { @@ -998,7 +1028,8 @@ plot_df_90["signal"] = pd.Categorical( ### Fit: 180 Days ```{python} -# | label: fit-180-days +#| label: fit-180-days + # Clear JAX caches to avoid interference jax.clear_caches() @@ -1036,7 +1067,8 @@ print(f"Elapsed time: {elapsed_180:.1f} seconds") We check the model fit, as before. ```{python} -# | label: arviz-diagnostics-180 +#| label: arviz-diagnostics-180 + # ArviZ diagnostics for 180-day fit idata_180 = az.from_numpyro( @@ -1055,13 +1087,11 @@ idata_180 = az.from_numpyro( ) idata_180_trimmed = idata_180.map_over_datasets(trim_time) -az.summary( - idata_180_trimmed, var_names=["latent_infections", "hospital_predicted"] -) +az.summary(idata_180_trimmed, var_names=["latent_infections", "hospital_predicted"]) ``` ```{python} -# | label: extract-quantiles-180 +#| label: extract-quantiles-180 latent_inf = idata_180_trimmed.posterior["latent_infections"] @@ -1078,8 +1108,10 @@ print(f" Median infections (day 90): {quantiles_180['q50'][90]:,.0f}") ``` ```{python} -# | label: fig-posterior-180 -# | fig-cap: Posterior latent infections and observed hospitalizations (180 days). +#| label: fig-posterior-180 +#| fig-cap: Posterior latent infections and observed hospitalizations (180 +#| days). + # Visualize posterior latent infections and observed hospitalizations (180 days) infections_df_180 = pd.DataFrame( { @@ -1093,9 +1125,7 @@ infections_df_180 = pd.DataFrame( # Add 14-day moving average to smooth noisy daily admissions hosp_raw_180 = np.array(hosp_admits[:n_days_180], dtype=float) -hosp_ma_180 = ( - pd.Series(hosp_raw_180).rolling(window=14, center=True).mean().values -) +hosp_ma_180 = pd.Series(hosp_raw_180).rolling(window=14, center=True).mean().values hosp_df_180 = pd.DataFrame( { @@ -1146,10 +1176,11 @@ plot_df_180["signal"] = pd.Categorical( ### Comparing 90-Day vs 180-Day Fits -Comparing the two fits reveals where uncertainty reduction occurs—and why it matters for forecasting. +Comparing the two fits reveals where uncertainty reduction occurs---and why it matters for forecasting. ```{python} -# | label: compare-fits +#| label: compare-fits + # Compare CI widths for the overlapping 90-day period ci_width_90_overlap = quantiles_90["q95"] - quantiles_90["q05"] ci_width_180_overlap = ( @@ -1161,12 +1192,8 @@ ci_diff = ci_width_90_overlap - ci_width_180_overlap ci_ratio = ci_width_90_overlap / ci_width_180_overlap print("CI Width Comparison (first 90 days):") -print( - f" 90-day fit mean CI width: {ci_width_90_overlap.mean():,.0f} infections" -) -print( - f" 180-day fit mean CI width: {ci_width_180_overlap.mean():,.0f} infections" -) +print(f" 90-day fit mean CI width: {ci_width_90_overlap.mean():,.0f} infections") +print(f" 180-day fit mean CI width: {ci_width_180_overlap.mean():,.0f} infections") print(f" Mean difference: {ci_diff.mean():,.0f} infections") print(f" Mean ratio (90/180): {ci_ratio.mean():.2f}x") print( @@ -1188,27 +1215,30 @@ for start, end, label in [ ) ``` -Notice that the uncertainty reduction is concentrated in days 60-90—the final month of the 90-day window. Earlier periods (days 0-60) show little change because both fits have sufficient future data to constrain those estimates. +Notice that the uncertainty reduction is concentrated in days 60-90---the final month of the 90-day window. +Earlier periods (days 0-60) show little change because both fits have sufficient future data to constrain those estimates. -This pattern has a direct implication for forecasting: **renewal models are most uncertain at the edge of the observation window**. Future observations constrain past latent infections through the renewal equation, but when predicting beyond available data, this constraint disappears. The high uncertainty in days 60-90 of the 90-day fit is exactly what we'd expect when forecasting 30 days ahead—there's no future signal to anchor the estimates. +This pattern has a direct implication for forecasting: **renewal models are most uncertain at the edge of the observation window**. +Future observations constrain past latent infections through the renewal equation, but when predicting beyond available data, this constraint disappears. +The high uncertainty in days 60-90 of the 90-day fit is exactly what we'd expect when forecasting 30 days ahead---there's no future signal to anchor the estimates. ## Summary This tutorial demonstrated composing a multi-signal renewal model using `PyrenewBuilder`: -1. **Configure latent process** (`configure_latent`): generation interval, initial infections, temporal dynamics -2. **Add observation processes** (`add_observation`): each declares its infection resolution and gets a name for data binding -3. **Build and run** (`build`, `model.run`): the model routes infections to observations based on resolution and runs NUTS inference +1. **Configure latent process** (`configure_latent`): generation interval, initial infections, temporal dynamics +2. **Add observation processes** (`add_observation`): each declares its infection resolution and gets a name for data binding +3. **Build and run** (`build`, `model.run`): the model routes infections to observations based on resolution and runs NUTS inference ### Key Concepts -- **Two-part structure**: Renewal models separate latent infection dynamics from observation processes -- **Infection resolution**: Observation processes declare whether they need aggregate or subpop-level infections -- **Data routing**: `PyrenewBuilder` automatically routes infection trajectories to the appropriate observation processes -- **Time alignment**: Observations must be offset by `n_initialization_points` to align with model time +- **Two-part structure**: Renewal models separate latent infection dynamics from observation processes +- **Infection resolution**: Observation processes declare whether they need aggregate or subpop-level infections +- **Data routing**: `PyrenewBuilder` automatically routes infection trajectories to the appropriate observation processes +- **Time alignment**: Observations must be offset by `n_initialization_points` to align with model time ### Next Steps -- Explore different temporal processes for $\mathcal{R}(t)$ in the [Latent Infections](latent_infections.md) and [Latent Subpopulation Infections](latent_subpopulation_infections.md) tutorials -- Learn about count-based observation models in [Observation Processes: Counts](observation_processes_counts.md) -- Learn about continuous measurement models in [Observation Processes: Measurements](observation_processes_measurements.md) +- Explore different temporal processes for $\mathcal{R}(t)$ in the [Latent Infections](latent_infections.md) and [Latent Subpopulation Infections](latent_subpopulation_infections.md) tutorials +- Learn about count-based observation models in [Observation Processes: Counts](observation_processes_counts.md) +- Learn about continuous measurement models in [Observation Processes: Measurements](observation_processes_measurements.md) diff --git a/docs/tutorials/day_of_week_effects.qmd b/docs/tutorials/day_of_week_effects.qmd index 147bfc8c..83429d17 100644 --- a/docs/tutorials/day_of_week_effects.qmd +++ b/docs/tutorials/day_of_week_effects.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -18,8 +18,9 @@ jupyter: --- ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax.numpy as jnp import numpy as np import numpyro @@ -48,7 +49,9 @@ Ignoring this weekly periodicity forces the noise model to absorb systematic var PyRenew models day-of-week effects as a **multiplicative adjustment** applied to predicted counts after the delay convolution and ascertainment scaling: -$$\lambda(t) = d_{w(t)} \cdot \alpha \sum_{s} I(t-s)\,\pi(s)$$ +$$ +\lambda(t) = d_{w(t)} \cdot \alpha \sum_{s} I(t-s)\,\pi(s) +$$ where $d_{w(t)}$ is the day-of-week multiplier for the weekday of timepoint $t$, $\alpha$ is the ascertainment rate, and $\pi(s)$ is the delay PMF. The effect vector $\mathbf{d} = (d_0, d_1, \ldots, d_6)$ has one entry per day (0=Monday through 6=Sunday, ISO convention). @@ -60,7 +63,8 @@ When the effects sum to 7.0, the average daily multiplier is 1.0, preserving wee A typical pattern for ED visits might show weekday effects above 1.0 and weekend effects below 1.0: ```{python} -# | label: define-dow-effect +#| label: define-dow-effect + dow_values = jnp.array([1.20, 1.15, 1.10, 1.05, 1.00, 0.75, 0.75]) day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] @@ -69,11 +73,10 @@ print(f"Sum: {float(jnp.sum(dow_values)):.2f}") ``` ```{python} -# | label: plot-dow-effect +#| label: plot-dow-effect + dow_df = pd.DataFrame({"day": day_names, "effect": np.array(dow_values)}) -dow_df["day"] = pd.Categorical( - dow_df["day"], categories=day_names, ordered=True -) +dow_df["day"] = pd.Categorical(dow_df["day"], categories=day_names, ordered=True) ( p9.ggplot(dow_df, p9.aes(x="day", y="effect")) @@ -97,11 +100,10 @@ We construct two `PopulationCounts` observation processes using the same delay d The only difference is whether `day_of_week_rv` is provided. ```{python} -# | label: create-processes +#| label: create-processes + hosp_delay_pmf = jnp.array( - datasets.load_example_infection_admission_interval()[ - "probability_mass" - ].to_numpy() + datasets.load_example_infection_admission_interval()["probability_mass"].to_numpy() ) delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) ihr_rv = DeterministicVariable("ihr", 0.01) @@ -128,7 +130,8 @@ The `first_day_dow` parameter tells PyRenew which day of the week corresponds to Here we set `first_day_dow=0` (Monday). ```{python} -# | label: simulate-and-sample +#| label: simulate-and-sample + day_one = process_no_dow.lookback_days() n_total = 130 infections = 5000.0 * jnp.exp(0.03 * jnp.arange(n_total)) @@ -142,7 +145,8 @@ with numpyro.handlers.seed(rng_seed=0): ``` ```{python} -# | label: plot-predicted-comparison +#| label: plot-predicted-comparison + n_plot_days = n_total - day_one pred_rows = [] for i in range(n_plot_days): @@ -169,9 +173,7 @@ pred_df["type"] = pd.Categorical( ) ( - p9.ggplot( - pred_df, p9.aes(x="day", y="admissions", color="type", linetype="type") - ) + p9.ggplot(pred_df, p9.aes(x="day", y="admissions", color="type", linetype="type")) + p9.geom_line(size=1) + p9.scale_color_manual(values=["steelblue", "#e41a1c"]) + p9.scale_linetype_manual(values=["solid", "dashed"]) @@ -187,7 +189,7 @@ pred_df["type"] = pd.Categorical( ``` Without the day-of-week effect the predicted curve is smooth. -With it, the curve oscillates with a 7-day period — dipping on weekends and rising on weekdays — while following the same overall trend. +With it, the curve oscillates with a 7-day period --- dipping on weekends and rising on weekdays --- while following the same overall trend. ## Effect of the offset @@ -196,7 +198,8 @@ Changing it shifts which days receive which multiplier. Here we compare starting on Monday vs. Wednesday: ```{python} -# | label: offset-comparison +#| label: offset-comparison + with numpyro.handlers.seed(rng_seed=0): result_monday = process_with_dow.sample( infections=infections, obs=None, first_day_dow=0 @@ -208,7 +211,8 @@ with numpyro.handlers.seed(rng_seed=0): ``` ```{python} -# | label: plot-offset-comparison +#| label: plot-offset-comparison + offset_rows = [] for i in range(21): day_idx = day_one + i @@ -255,7 +259,7 @@ offset_df["offset"] = pd.Categorical( ``` The two curves have the same shape but are phase-shifted: their weekend dips fall on different days. -Getting `first_day_dow` right matters — a misaligned offset would attribute Monday's high to Sunday or vice versa. +Getting `first_day_dow` right matters --- a misaligned offset would attribute Monday's high to Sunday or vice versa. When using `MultiSignalModel`, pass `obs_start_date` (the date of the first observation day) to `model.sample()` or `model.run()`. The model handles the calendar bookkeeping and forwards the day-of-week information to every component that needs it. @@ -266,7 +270,8 @@ Day-of-week effects shape the noise draws, not just the predicted means. The noise model samples from a distribution centered on the adjusted predictions, so sampled observations inherit the weekly pattern. ```{python} -# | label: sample-noisy +#| label: sample-noisy + n_samples = 30 noisy_results = [] for seed in range(n_samples): @@ -297,15 +302,14 @@ for seed in range(n_samples): ``` ```{python} -# | label: plot-noisy +#| label: plot-noisy + noisy_df = pd.DataFrame(noisy_results) mean_df = noisy_df.groupby(["day", "type"])["admissions"].mean().reset_index() ( p9.ggplot(noisy_df, p9.aes(x="day", y="admissions")) - + p9.geom_line( - p9.aes(group="sample"), alpha=0.15, size=0.4, color="steelblue" - ) + + p9.geom_line(p9.aes(group="sample"), alpha=0.15, size=0.4, color="steelblue") + p9.geom_line( data=mean_df, mapping=p9.aes(x="day", y="admissions"), @@ -323,17 +327,20 @@ mean_df = noisy_df.groupby(["day", "type"])["admissions"].mean().reset_index() ``` The top panel shows smooth variation around the trend. -The bottom panel shows systematic weekly oscillation in both the mean (red) and individual samples (blue) — the weekend dips are visible even through the noise. +The bottom panel shows systematic weekly oscillation in both the mean (red) and individual samples (blue) --- the weekend dips are visible even through the noise. ## Composing with right-truncation Day-of-week effects and right-truncation are independent adjustments that compose naturally. Day-of-week is applied first (adjusting the expected counts for reporting patterns), then right-truncation scales down recent counts for incomplete reporting: -$$\lambda(t) = F(k_t) \cdot d_{w(t)} \cdot \alpha \sum_s I(t-s)\,\pi(s)$$ +$$ +\lambda(t) = F(k_t) \cdot d_{w(t)} \cdot \alpha \sum_s I(t-s)\,\pi(s) +$$ ```{python} -# | label: compose-with-truncation +#| label: compose-with-truncation + reporting_delay_pmf = jnp.array([0.4, 0.3, 0.15, 0.08, 0.04, 0.02, 0.01]) process_both = PopulationCounts( @@ -342,9 +349,7 @@ process_both = PopulationCounts( delay_distribution_rv=delay_rv, noise=NegativeBinomialNoise(concentration_rv), day_of_week_rv=DeterministicVariable("dow_effect", dow_values), - right_truncation_rv=DeterministicPMF( - "reporting_delay", reporting_delay_pmf - ), + right_truncation_rv=DeterministicPMF("reporting_delay", reporting_delay_pmf), ) with numpyro.handlers.seed(rng_seed=0): @@ -357,7 +362,8 @@ with numpyro.handlers.seed(rng_seed=0): ``` ```{python} -# | label: plot-composed +#| label: plot-composed + compose_rows = [] for i in range(n_plot_days): day_idx = day_one + i @@ -403,16 +409,16 @@ compose_df["type"] = pd.Categorical( The two curves agree in the early period. Near the right edge, right-truncation pulls the curve downward on top of the weekly oscillation. -Each adjustment operates on its own concern — weekly reporting patterns vs. incomplete recent data — and they combine multiplicatively without interfering. +Each adjustment operates on its own concern --- weekly reporting patterns vs. incomplete recent data --- and they combine multiplicatively without interfering. ## Summary Day-of-week adjustment is enabled by passing a `day_of_week_rv` at construction time and a `first_day_dow` at sample time. -| Parameter | Where | Purpose | -|-----------|-------|---------| -| `day_of_week_rv` | Constructor | 7-element multiplicative effect vector (0=Mon, 6=Sun) | -| `first_day_dow` | `sample()` | Day of the week for element 0 of the time axis | + | Parameter | Where | Purpose | + | ---------------- | ----------- | ----------------------------------------------------- | + | `day_of_week_rv` | Constructor | 7-element multiplicative effect vector (0=Mon, 6=Sun) | + | `first_day_dow` | `sample()` | Day of the week for element 0 of the time axis | When either is `None`, the adjustment is disabled and the process behaves identically to one without day-of-week effects. diff --git a/docs/tutorials/latent_infections.qmd b/docs/tutorials/latent_infections.qmd index b46a9a51..22d781ef 100644 --- a/docs/tutorials/latent_infections.qmd +++ b/docs/tutorials/latent_infections.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -18,8 +18,9 @@ jupyter: --- ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax.numpy as jnp import jax.random as random import numpy as np @@ -37,7 +38,8 @@ from _tutorial_theme import theme_tutorial ``` ```{python} -# | label: imports +#| label: imports + from pyrenew.latent import ( PopulationInfections, AR1, @@ -50,7 +52,8 @@ from pyrenew.randomvariable import DistributionalVariable ## Overview -In infectious disease modeling, the true number of new infections at each time point is not directly observed. Instead, we observe indirect signals: hospital admissions, emergency department visits, reported cases, wastewater concentrations. +In infectious disease modeling, the true number of new infections at each time point is not directly observed. +Instead, we observe indirect signals: hospital admissions, emergency department visits, reported cases, wastewater concentrations. The latent infection process generates the unobserved infection trajectory that underlies all of these signals. PyRenew models latent infections using the **renewal equation**, which describes how new infections arise from recent past infections. @@ -71,8 +74,11 @@ Here, $\tau$ indexes lags in the generation interval. PyRenew provides two latent infection classes: -- **`PopulationInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. Appropriate when modeling one jurisdiction as a single population with one or more observation streams. This tutorial covers `PopulationInfections`. -- **`SubpopulationInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. See [Latent Subpopulation Infections](latent_subpopulation_infections.md). +- **`PopulationInfections`**: A single $\mathcal{R}(t)$ drives one renewal equation. + Appropriate when modeling one jurisdiction as a single population with one or more observation streams. + This tutorial covers `PopulationInfections`. +- **`SubpopulationInfections`**: A baseline $\mathcal{R}(t)$ with per-subpopulation deviations. + See [Latent Subpopulation Infections](latent_subpopulation_infections.md). ## Model Inputs @@ -83,7 +89,10 @@ PyRenew provides two latent infection classes: 3. Value for $log(\mathcal{R}(t))$ at time $0$ (`log_rt_time_0_rv`) 4. A temporal process for $\mathcal{R}(t)$ dynamics (`single_rt_process`) -All inputs are **RandomVariables**, a quantity that is either known (observed, conditioned on) or unknown (to be inferred). See [PyRenew's RandomVariable abstract base class](random_variables.md). In this tutorial, we use `DeterministicVariable` and `DeterministicPMF` (fixed values) for illustration. In real inference, you would use `DistributionalVariable` with priors for quantities you want to estimate: +All inputs are **RandomVariables**, a quantity that is either known (observed, conditioned on) or unknown (to be inferred). +See [PyRenew's RandomVariable abstract base class](random_variables.md). +In this tutorial, we use `DeterministicVariable` and `DeterministicPMF` (fixed values) for illustration. +In real inference, you would use `DistributionalVariable` with priors for quantities you want to estimate: ```python # Fixed value (for illustration) @@ -95,10 +104,15 @@ I0_rv = DistributionalVariable("I0", dist.Beta(1, 1000)) ### Generation Interval -The generation interval is a key epidemiological input to the renewal equation. In many renewal-model analyses, it is specified from external studies rather than estimated jointly from the focal surveillance data. Because `R_t` estimates depend on this choice, sensitivity analysis is often warranted. For COVID-19, most transmission occurs within the first few days. We define a 7-day distribution for illustration: +The generation interval is a key epidemiological input to the renewal equation. +In many renewal-model analyses, it is specified from external studies rather than estimated jointly from the focal surveillance data. +Because `R_t` estimates depend on this choice, sensitivity analysis is often warranted. +For COVID-19, most transmission occurs within the first few days. +We define a 7-day distribution for illustration: ```{python} -# | label: generation-interval +#| label: generation-interval + gen_int_pmf = jnp.array([0.16, 0.32, 0.25, 0.14, 0.07, 0.04, 0.02]) gen_int_rv = DeterministicPMF("gen_int", gen_int_pmf) @@ -108,8 +122,10 @@ print(f"Generation interval: {len(gen_int_pmf)} days, mean = {mean_gi:.1f}") ``` ```{python} -# | label: fig-generation-interval -# | fig-cap: COVID-like generation interval distribution. The support is shown in whole days since infection of the primary case. +#| label: fig-generation-interval +#| fig-cap: COVID-like generation interval distribution. The support is shown in +#| whole days since infection of the primary case. + gi_df = pd.DataFrame({"day": days, "probability": np.array(gen_int_pmf)}) ( @@ -128,47 +144,35 @@ The generation interval length determines the minimum initialization period: wit ### Initial Conditions: `I0` and `log_rt_time_0` -These two parameters jointly define the infection history before the observation -period begins. Understanding their interaction requires knowing how the latent -infections process initializes the renewal equation. - -**The backprojection mechanism.** The renewal equation at time $0$ needs a -vector of recent infections to convolve with the generation interval. The latent -infections process constructs this initialization vector via exponential -backprojection. Let $\tau = 0, 1, \ldots, n_{\text{init}} - 1$ index positions -in the initialization vector, where $\tau = 0$ is the earliest time point and -$\tau = n_{\text{init}} - 1$ is the time point immediately before the -observation period begins. Then: - -$$I_{\text{init}}(\tau) = I_0 \cdot e^{r \cdot \tau}, \quad \tau = 0, 1, -\ldots, n_{\text{init}} - 1$$ - -where $r$ is the asymptotic growth rate implied by the reproduction number at -the start of the observation period, $\mathcal{R}(t=0) = -e^{\text{log\_rt\_time\_0}}$, and the generation interval. The function -`r_approx_from_R` converts $\mathcal{R}(t=0)$ and the generation interval into -$r$ using Newton's method. - -* **The level is set by `I0`**.
-`I0` is the infection prevalence at the earliest point in the initialization -period, $n_{\text{init}} - 1$ time points before $t = 0$. It sets the scale of -the entire initialization vector: $I_{\text{init}}(0) = I_0$, with subsequent -entries growing or declining exponentially toward $t = 0$. - -* **The shape is set by `log_rt_time_0`**.
-`log_rt_time_0` enters the model in two places: it is the starting point of -the $\mathcal{R}(t)$ trajectory ($\mathcal{R}(t=0) = -e^{\text{log\_rt\_time\_0}}$), and it determines the exponential growth rate -$r$ used to construct the initialization vector. When `log_rt_time_0 = 0`, -$r = 0$ and the initialization vector is flat at level `I0`. When -`log_rt_time_0 > 0`, infections are growing exponentially at $t = 0$; when -`log_rt_time_0 < 0`, they are declining. - - -The initialization vector is what the renewal equation "sees" as recent infection history at time $0$. We can compute it directly for three values of `log_rt_time_0`: +These two parameters jointly define the infection history before the observation period begins. +Understanding their interaction requires knowing how the latent infections process initializes the renewal equation. + +**The backprojection mechanism.** The renewal equation at time $0$ needs a vector of recent infections to convolve with the generation interval. +The latent infections process constructs this initialization vector via exponential backprojection. +Let $\tau = 0, 1, \ldots, n_{\text{init}} - 1$ index positions in the initialization vector, where $\tau = 0$ is the earliest time point and $\tau = n_{\text{init}} - 1$ is the time point immediately before the observation period begins. +Then: + +$$ +I_{\text{init}}(\tau) = I_0 \cdot e^{r \cdot \tau}, \quad \tau = 0, 1, +\ldots, n_{\text{init}} - 1 +$$ + +where $r$ is the asymptotic growth rate implied by the reproduction number at the start of the observation period, $\mathcal{R}(t=0) = e^{\text{log\_rt\_time\_0}}$, and the generation interval. +The function `r_approx_from_R` converts $\mathcal{R}(t=0)$ and the generation interval into $r$ using Newton's method. + +- **The level is set by `I0`**.
`I0` is the infection prevalence at the earliest point in the initialization period, $n_{\text{init}} - 1$ time points before $t = 0$. + It sets the scale of the entire initialization vector: $I_{\text{init}}(0) = I_0$, with subsequent entries growing or declining exponentially toward $t = 0$. + +- **The shape is set by `log_rt_time_0`**.
`log_rt_time_0` enters the model in two places: it is the starting point of the $\mathcal{R}(t)$ trajectory ($\mathcal{R}(t=0) = e^{\text{log\_rt\_time\_0}}$), and it determines the exponential growth rate $r$ used to construct the initialization vector. + When `log_rt_time_0 = 0`, $r = 0$ and the initialization vector is flat at level `I0`. + When `log_rt_time_0 > 0`, infections are growing exponentially at $t = 0$; when `log_rt_time_0 < 0`, they are declining. + +The initialization vector is what the renewal equation "sees" as recent infection history at time $0$. +We can compute it directly for three values of `log_rt_time_0`: ```{python} -# | label: backprojection-compute +#| label: backprojection-compute + from pyrenew.math import r_approx_from_R n_init = len(gen_int_pmf) @@ -202,16 +206,19 @@ init_df = pd.DataFrame(init_data) ``` ```{python} -# | label: fig-backprojection -# | fig-cap: Initialization vectors for three values of log_rt_time_0. Days are numbered relative to day 0, which is when the temporal process and renewal equation take over. When log_rt_time_0 = 0 (stable), the vector is flat. Nonzero values produce exponential growth or decay in the pre-observation period. +#| label: fig-backprojection +#| fig-cap: Initialization vectors for three values of log_rt_time_0. Days are +#| numbered relative to day 0, which is when the temporal process and renewal +#| equation take over. When log_rt_time_0 = 0 (stable), the vector is flat. +#| Nonzero values produce exponential growth or decay in the pre-observation +#| period. + ( p9.ggplot(init_df, p9.aes(x="day", y="infections", color="config")) + p9.geom_line(size=1) + p9.geom_point(size=2) + p9.geom_vline(xintercept=-0.5, linetype="dashed", alpha=0.5) - + p9.annotate( - "text", x=-0.3, y=I0 * 1.15, label="day 0 -->", ha="right", size=10 - ) + + p9.annotate("text", x=-0.3, y=I0 * 1.15, label="day 0 -->", ha="right", size=10) + p9.labs( x="Day (relative to observation start)", y="Infections (proportion)", @@ -223,21 +230,28 @@ init_df = pd.DataFrame(init_data) ) ``` -The initialization vector matters because the renewal equation is a convolution: infections on day 0 depend on infections from days $-1$ through $-(K-1)$, weighted by the generation interval. A flat initialization (stable) means the renewal equation starts with uniform recent history. A growing initialization means the most recent days have disproportionately more infections, which amplifies the effect of the generation interval's short-lag weights. +The initialization vector matters because the renewal equation is a convolution: infections on day 0 depend on infections from days $-1$ through $-(K-1)$, weighted by the generation interval. +A flat initialization (stable) means the renewal equation starts with uniform recent history. +A growing initialization means the most recent days have disproportionately more infections, which amplifies the effect of the generation interval's short-lag weights. -After day 0, the temporal process takes over. How quickly the trajectory departs from its initial behavior depends on the temporal process choice and its hyperparameters, which we examine in the next section. +After day 0, the temporal process takes over. +How quickly the trajectory departs from its initial behavior depends on the temporal process choice and its hyperparameters, which we examine in the next section. ## Temporal Process Choice The temporal process governs how $\log \mathcal{R}(t)$ evolves day to day. -To evaluate what a given process implies, we use **prior predictive checks**: drawing many samples from the model *before seeing any data* and examining the distribution of trajectories. A single sample tells you little (the trajectory depends on the random seed), but the envelope of many samples reveals the structural constraints built into the process. -We fix the initial conditions to a growing epidemic (`log_rt_time_0 = 0.5`, so $\mathcal{R}(0) \approx 1.65$) with `I0 = 0.001`. Starting well above equilibrium rather than near it makes the behavioral differences between temporal processes visible: the median trajectory of a mean-reverting process drifts back toward $\mathcal{R} = 1$, while a non-reverting process does not. +To evaluate what a given process implies, we use **prior predictive checks**: drawing many samples from the model *before seeing any data* and examining the distribution of trajectories. +A single sample tells you little (the trajectory depends on the random seed), but the envelope of many samples reveals the structural constraints built into the process. +We fix the initial conditions to a growing epidemic (`log_rt_time_0 = 0.5`, so $\mathcal{R}(0) \approx 1.65$) with `I0 = 0.001`. +Starting well above equilibrium rather than near it makes the behavioral differences between temporal processes visible: the median trajectory of a mean-reverting process drifts back toward $\mathcal{R} = 1$, while a non-reverting process does not. -This section is primarily **modeling guidance for prior specification in PyRenew**, not a claim that epidemiologic theory uniquely determines one temporal process choice. The appropriate process depends on the scientific setting, time horizon, and how strongly you want the prior to regularize latent transmission dynamics. +This section is primarily **modeling guidance for prior specification in PyRenew**, not a claim that epidemiologic theory uniquely determines one temporal process choice. +The appropriate process depends on the scientific setting, time horizon, and how strongly you want the prior to regularize latent transmission dynamics. ```{python} -# | label: prior-predictive-config +#| label: prior-predictive-config + n_days = 28 n_init = len(gen_int_pmf) n_samples = 200 @@ -263,12 +277,10 @@ def sample_process(rt_process, label): samples = Predictive(sampler, num_samples=n_samples)(random.PRNGKey(42)) return { - "rt": np.array(samples["PopulationInfections::rt_single"])[ - :, n_init:, 0 + "rt": np.array(samples["PopulationInfections::rt_single"])[:, n_init:, 0], + "infections": np.array(samples["PopulationInfections::infections_aggregate"])[ + :, n_init: ], - "infections": np.array( - samples["PopulationInfections::infections_aggregate"] - )[:, n_init:], } @@ -278,9 +290,7 @@ def rt_spaghetti_plot(rt_array, title, color="steelblue"): spaghetti_data = [] for i in range(n_traj): for d in range(n_t): - spaghetti_data.append( - {"day": d, "rt": float(rt_array[i, d]), "sample": i} - ) + spaghetti_data.append({"day": d, "rt": float(rt_array[i, d]), "sample": i}) df = pd.DataFrame(spaghetti_data) summary_df = pd.DataFrame( { @@ -311,9 +321,7 @@ def rt_spaghetti_plot(rt_array, title, color="steelblue"): color=color, size=1.0, ) - + p9.geom_hline( - yintercept=1.0, color="red", linetype="dashed", alpha=0.7 - ) + + p9.geom_hline(yintercept=1.0, color="red", linetype="dashed", alpha=0.7) + p9.coord_cartesian(ylim=(0, rt_cap)) + p9.labs(x="Days", y="Rt", title=title) + theme_tutorial @@ -342,66 +350,82 @@ def rt_summary(rt_array, label): The simplest temporal process. -$$x_t = x_{t-1} + \varepsilon_t, \quad \varepsilon_t \sim N(0, \sigma)$$ +$$ +x_t = x_{t-1} + \varepsilon_t, \quad \varepsilon_t \sim N(0, \sigma) +$$ -There is no tendency to return to any particular value. Once $\mathcal{R}(t)$ drifts up or down, it stays there until noise pushes it elsewhere. The variance of $x_t$ grows linearly with time: $\text{Var}(x_t) = \sigma^2 t$. The further into the future, the less constrained the process is. +There is no tendency to return to any particular value. +Once $\mathcal{R}(t)$ drifts up or down, it stays there until noise pushes it elsewhere. +The variance of $x_t$ grows linearly with time: $\text{Var}(x_t) = \sigma^2 t$. +The further into the future, the less constrained the process is. -**Hyperparameter:** `innovation_sd` ($\sigma$) is the standard deviation of each daily step on the log scale. With `innovation_sd = 0.05`, each day's $\log \mathcal{R}$ changes by roughly $\pm 0.05$, which corresponds to roughly $\pm 5\%$ multiplicative change in $\mathcal{R}$. +**Hyperparameter:** `innovation_sd` ($\sigma$) is the standard deviation of each daily step on the log scale. +With `innovation_sd = 0.05`, each day's $\log \mathcal{R}$ changes by roughly $\pm 0.05$, which corresponds to roughly $\pm 5\%$ multiplicative change in $\mathcal{R}$. ```{python} -# | label: rw-sample +#| label: rw-sample + rw_samples = sample_process(RandomWalk(innovation_sd=0.05), "RandomWalk") ``` ```{python} -# | label: fig-rw -# | fig-cap: 'Prior predictive $\mathcal{R}(t)$ trajectories under a Random Walk (innovation_sd = 0.05). The envelope widens steadily over time because variance grows linearly. Trajectories drift in both directions from the starting value with no tendency to return.' +#| label: fig-rw +#| fig-cap: 'Prior predictive $\mathcal{R}(t)$ trajectories under a Random Walk (innovation_sd = 0.05). The envelope widens steadily over time because variance grows linearly. Trajectories drift in both directions from the starting value with no tendency to return.' + rt_spaghetti_plot(rw_samples["rt"], "Random Walk (innovation_sd = 0.05)") ``` ```{python} -# | label: rw-summary +#| label: rw-summary + rt_summary(rw_samples["rt"], "Random Walk (innovation_sd = 0.05)") ``` ### AR(1) -An autoregressive process of order 1. Each day, $\log \mathcal{R}(t)$ is pulled toward zero (i.e., $\mathcal{R} = 1$) by the autoregressive coefficient $\phi$: +An autoregressive process of order 1. +Each day, $\log \mathcal{R}(t)$ is pulled toward zero (i.e., $\mathcal{R} = 1$) by the autoregressive coefficient $\phi$: -$$x_t = \phi \, x_{t-1} + \varepsilon_t, \quad \varepsilon_t \sim N(0, \sigma)$$ +$$ +x_t = \phi \, x_{t-1} + \varepsilon_t, \quad \varepsilon_t \sim N(0, \sigma) +$$ -When $|\phi| < 1$, the process is **stationary**: its variance does not grow with time but is bounded at $\sigma^2 / (1 - \phi^2)$. This means the process "forgets" its initial value and fluctuates within a stable envelope. If $\mathcal{R}(t)$ drifts above 1, the $\phi$ coefficient pulls it back; if it drifts below, it pulls it up. +When $|\phi| < 1$, the process is **stationary**: its variance does not grow with time but is bounded at $\sigma^2 / (1 - \phi^2)$. +This means the process "forgets" its initial value and fluctuates within a stable envelope. +If $\mathcal{R}(t)$ drifts above 1, the $\phi$ coefficient pulls it back; if it drifts below, it pulls it up. **Hyperparameters:** -- `autoreg` ($\phi$): controls the strength and speed of mean reversion. Values near 1 produce slow reversion (the process remembers its recent past); values near 0 produce fast reversion (the process snaps back quickly). +- `autoreg` ($\phi$): controls the strength and speed of mean reversion. + Values near 1 produce slow reversion (the process remembers its recent past); values near 0 produce fast reversion (the process snaps back quickly). - `innovation_sd` ($\sigma$): standard deviation of daily noise. -The two hyperparameters jointly determine the **stationary standard deviation** $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$, which is the long-run spread of $\log \mathcal{R}(t)$. For example, `autoreg = 0.9` and `innovation_sd = 0.05` give $\sigma_{\text{stat}} \approx 0.115$, meaning 95% of long-run $\log \mathcal{R}$ values fall within $\pm 0.23$ of zero, or equivalently $\mathcal{R} \in [0.79, 1.26]$. +The two hyperparameters jointly determine the **stationary standard deviation** $\sigma_{\text{stat}} = \sigma / \sqrt{1 - \phi^2}$, which is the long-run spread of $\log \mathcal{R}(t)$. +For example, `autoreg = 0.9` and `innovation_sd = 0.05` give $\sigma_{\text{stat}} \approx 0.115$, meaning 95% of long-run $\log \mathcal{R}$ values fall within $\pm 0.23$ of zero, or equivalently $\mathcal{R} \in [0.79, 1.26]$. ```{python} -# | label: ar1-sample +#| label: ar1-sample + ar1_samples = sample_process(AR1(autoreg=0.9, innovation_sd=0.05), "AR1") ``` ```{python} -# | label: fig-ar1 -# | fig-cap: 'Prior predictive $\mathcal{R}(t)$ trajectories under AR(1) (autoreg = 0.9, innovation_sd = 0.05). The envelope stabilizes rather than growing over time. The median trajectory (solid line) drifts from the starting $\mathcal{R}(0)$ of 1.65 back toward 1, illustrating mean reversion of the level.' -rt_spaghetti_plot( - ar1_samples["rt"], "AR(1) (autoreg = 0.9, innovation_sd = 0.05)" -) +#| label: fig-ar1 +#| fig-cap: 'Prior predictive $\mathcal{R}(t)$ trajectories under AR(1) (autoreg = 0.9, innovation_sd = 0.05). The envelope stabilizes rather than growing over time. The median trajectory (solid line) drifts from the starting $\mathcal{R}(0)$ of 1.65 back toward 1, illustrating mean reversion of the level.' + +rt_spaghetti_plot(ar1_samples["rt"], "AR(1) (autoreg = 0.9, innovation_sd = 0.05)") ``` ```{python} -# | label: ar1-summary +#| label: ar1-summary + rt_summary(ar1_samples["rt"], "AR(1) (autoreg = 0.9, innovation_sd = 0.05)") ``` - - ### DifferencedAR1 -DifferencedAR1 models *changes* in $\mathcal{R}(t)$ as an autoregressive process. In terms of increments, +DifferencedAR1 models *changes* in $\mathcal{R}(t)$ as an autoregressive process. +In terms of increments, $$ \Delta x_t = \texttt{autoreg} \cdot \Delta x_{t-1} + \varepsilon_t, \quad x_t = x_{t-1} + \Delta x_t, @@ -409,26 +433,32 @@ $$ so the autoregressive structure applies to the *rate of change* rather than the level. -This has an important consequence: while changes in $\mathcal{R}(t)$ are mean-reverting, the level itself is not. As a result, trajectories can exhibit sustained upward or downward trends and may drift over time, even when the innovations are small. -This process accumulates changes over time, so variability builds up in the level. Because successive changes are correlated, increases or decreases can persist for several steps in a row. These sustained runs can lead to wider trajectories than a random walk with the same `innovation_sd`, even though the individual innovations are small. +This has an important consequence: while changes in $\mathcal{R}(t)$ are mean-reverting, the level itself is not. +As a result, trajectories can exhibit sustained upward or downward trends and may drift over time, even when the innovations are small. +This process accumulates changes over time, so variability builds up in the level. +Because successive changes are correlated, increases or decreases can persist for several steps in a row. +These sustained runs can lead to wider trajectories than a random walk with the same `innovation_sd`, even though the individual innovations are small. **Hyperparameters:** -- `autoreg` ($\phi$): controls persistence of the rate of change. Values near 1 mean trends persist longer; values near 0 mean trends dissipate quickly. +- `autoreg` ($\phi$): controls persistence of the rate of change. + Values near 1 mean trends persist longer; values near 0 mean trends dissipate quickly. - `innovation_sd` ($\sigma$): standard deviation of shocks to the rate of change (not to the level). We use `innovation_sd = 0.01` (smaller than the 0.05 used for Random Walk and AR(1) above). ```{python} -# | label: dar1-sample +#| label: dar1-sample + dar1_samples = sample_process( DifferencedAR1(autoreg=0.5, innovation_sd=0.01), "DifferencedAR1" ) ``` ```{python} -# | label: fig-dar1 -# | fig-cap: 'DifferencedAR1 allows sustained directional trends because the increments are autocorrelated. Even with a smaller innovation scale than the other examples, the induced prior on $\mathcal{R}(t)$ can still be fairly diffuse over this time horizon.' +#| label: fig-dar1 +#| fig-cap: 'DifferencedAR1 allows sustained directional trends because the increments are autocorrelated. Even with a smaller innovation scale than the other examples, the induced prior on $\mathcal{R}(t)$ can still be fairly diffuse over this time horizon.' + rt_spaghetti_plot( dar1_samples["rt"], "DifferencedAR1 (autoreg = 0.5, innovation_sd = 0.01)", @@ -436,23 +466,26 @@ rt_spaghetti_plot( ``` ```{python} -# | label: dar1-summary -rt_summary( - dar1_samples["rt"], "DifferencedAR1 (autoreg = 0.5, innovation_sd = 0.01)" -) +#| label: dar1-summary + +rt_summary(dar1_samples["rt"], "DifferencedAR1 (autoreg = 0.5, innovation_sd = 0.01)") ``` ### Comparing the Three Processes -The plots above use different hyperparameter values for each process. This is intentional: the parameters `innovation_sd` and `autoreg` play different roles across processes (daily step size for RandomWalk, noise around the level for AR(1), and shocks to the *rate of change* for DifferencedAR1), so equal numerical values are not directly comparable. +The plots above use different hyperparameter values for each process. +This is intentional: the parameters `innovation_sd` and `autoreg` play different roles across processes (daily step size for RandomWalk, noise around the level for AR(1), and shocks to the *rate of change* for DifferencedAR1), so equal numerical values are not directly comparable. -For AR(1), `autoreg` controls how strongly $\mathcal{R}(t)$ is pulled back toward $1$, directly limiting variation in the level. For DifferencedAR1, `autoreg` instead controls how persistent changes are over time. Because these changes accumulate, even small innovations can produce substantial drift in $\mathcal{R}(t)$ when persistence is high. - -As a result, DifferencedAR1 can exhibit wider trajectories than might be expected from the scale of `innovation_sd` alone. Hyperparameters are chosen to illustrate characteristic behavior rather than to match marginal variability, so the resulting spreads are not directly comparable across panels. +For AR(1), `autoreg` controls how strongly $\mathcal{R}(t)$ is pulled back toward $1$, directly limiting variation in the level. +For DifferencedAR1, `autoreg` instead controls how persistent changes are over time. +Because these changes accumulate, even small innovations can produce substantial drift in $\mathcal{R}(t)$ when persistence is high. +As a result, DifferencedAR1 can exhibit wider trajectories than might be expected from the scale of `innovation_sd` alone. +Hyperparameters are chosen to illustrate characteristic behavior rather than to match marginal variability, so the resulting spreads are not directly comparable across panels. ```{python} -# | label: comparison-data +#| label: comparison-data + comparison_data = [] process_labels = { "Random Walk\n(sd=0.05)": rw_samples, @@ -480,8 +513,10 @@ comparison_df["process"] = pd.Categorical( ``` ```{python} -# | label: fig-comparison -# | fig-cap: Side-by-side comparison of all three temporal processes using median trajectories and 90% prior intervals. +#| label: fig-comparison +#| fig-cap: Side-by-side comparison of all three temporal processes using median +#| trajectories and 90% prior intervals. + ( p9.ggplot( comparison_df.groupby(["process", "day"])["rt"] @@ -512,8 +547,9 @@ comparison_df["process"] = pd.Categorical( ``` ```{python} -# | label: tbl-comparison-summary -# | tbl-cap: '$\mathcal{R}(t)$ at day 28 by temporal process.' +#| label: tbl-comparison-summary +#| tbl-cap: '$\mathcal{R}(t)$ at day 28 by temporal process.' + summary_rows = [] for label, ar, sd, s in [ ("Random Walk", None, 0.05, rw_samples), @@ -537,31 +573,43 @@ pd.DataFrame(summary_rows) **When to use which process:** -- **AR(1)** assumes $\mathcal{R}(t)$ reverts to $1$. Appropriate when modeling a pathogen near endemic equilibrium, or over short time horizons where large sustained departures from the current level are implausible. -- **Random Walk** assumes no preferred direction and no memory beyond the current value. Appropriate when you have little prior knowledge about whether $\mathcal{R}(t)$ will increase or decrease, and the data are dense enough to constrain the trajectory. -- **DifferencedAR1** assumes trends can persist but the rate of change stabilizes. Appropriate for epidemic waves where $\mathcal{R}(t)$ can sustain a direction of movement (e.g., rising during a wave, declining after interventions). +- **AR(1)** assumes $\mathcal{R}(t)$ reverts to $1$. + Appropriate when modeling a pathogen near endemic equilibrium, or over short time horizons where large sustained departures from the current level are implausible. +- **Random Walk** assumes no preferred direction and no memory beyond the current value. + Appropriate when you have little prior knowledge about whether $\mathcal{R}(t)$ will increase or decrease, and the data are dense enough to constrain the trajectory. +- **DifferencedAR1** assumes trends can persist but the rate of change stabilizes. + Appropriate for epidemic waves where $\mathcal{R}(t)$ can sustain a direction of movement (e.g., rising during a wave, declining after interventions). ### Effect of Hyperparameters The hyperparameters `autoreg` and `innovation_sd` control how tightly the prior constrains $\mathcal{R}(t)$, but their roles differ in important ways for DifferencedAR1.\ -The `innovation_sd` parameter sets the scale of shocks to the *rate of change* in $\mathcal{R}(t)$, while `autoreg` controls how persistent those changes are over time. In terms of increments, where $\phi =$ `autoreg` - +The `innovation_sd` parameter sets the scale of shocks to the *rate of change* in $\mathcal{R}(t)$, while `autoreg` controls how persistent those changes are over time. +In terms of increments, where $\phi =$ `autoreg` -$$\Delta x_t = \phi \cdot \Delta x_{t-1} + \varepsilon_t$$ +$$ +\Delta x_t = \phi \cdot \Delta x_{t-1} + \varepsilon_t +$$ -`autoreg` ($\phi$) determines how strongly each change carries forward into the next. The level then accumulates these changes via $x_t = x_{t-1} + \Delta x_t$. +`autoreg` ($\phi$) determines how strongly each change carries forward into the next. +The level then accumulates these changes via $x_t = x_{t-1} + \Delta x_t$. -When `autoreg` is large (e.g., 0.9), increments are highly persistent: a positive change in $\mathcal{R}(t)$ is likely to be followed by further positive changes. This produces sustained upward or downward trends, even when `innovation_sd` is small. Because these persistent changes accumulate in the level, the resulting prior over $\mathcal{R}(t)$ can be relatively diffuse over time. +When `autoreg` is large (e.g., 0.9), increments are highly persistent: a positive change in $\mathcal{R}(t)$ is likely to be followed by further positive changes. +This produces sustained upward or downward trends, even when `innovation_sd` is small. +Because these persistent changes accumulate in the level, the resulting prior over $\mathcal{R}(t)$ can be relatively diffuse over time. -Lowering `autoreg` reduces this persistence. Increments decorrelate more quickly, causing positive and negative changes to cancel out, which limits long runs of same-direction movement and produces a tighter prior over $\mathcal{R}(t)$. +Lowering `autoreg` reduces this persistence. +Increments decorrelate more quickly, causing positive and negative changes to cancel out, which limits long runs of same-direction movement and produces a tighter prior over $\mathcal{R}(t)$. -Importantly, this differs from AR(1): in AR(1), `autoreg` controls mean reversion of the *level*, directly shrinking $\mathcal{R}(t)$ toward its mean. In DifferencedAR1, only the *changes* are mean-reverting, so the level itself can drift. As a result, controlling long-term variability in DifferencedAR1 typically requires adjusting both `innovation_sd` (step size) and `autoreg` (persistence of trends). +Importantly, this differs from AR(1): in AR(1), `autoreg` controls mean reversion of the *level*, directly shrinking $\mathcal{R}(t)$ toward its mean. +In DifferencedAR1, only the *changes* are mean-reverting, so the level itself can drift. +As a result, controlling long-term variability in DifferencedAR1 typically requires adjusting both `innovation_sd` (step size) and `autoreg` (persistence of trends). We illustrate these effects below for DifferencedAR1 by varying both parameters: ```{python} -# | label: hyperparam-sample +#| label: hyperparam-sample + hyperparam_configs = { "autoreg=0.5, sd=0.01": DifferencedAR1(autoreg=0.5, innovation_sd=0.01), "autoreg=0.9, sd=0.02": DifferencedAR1(autoreg=0.9, innovation_sd=0.02), @@ -574,7 +622,8 @@ for name, rt_process in hyperparam_configs.items(): ``` ```{python} -# | label: hyperparam-rt-data +#| label: hyperparam-rt-data + hp_data = [] for name in hyperparam_configs: rt = hyperparam_samples[name]["rt"] @@ -592,8 +641,9 @@ hp_df["config"] = pd.Categorical( ``` ```{python} -# | label: fig-hyperparam-effect -# | fig-cap: 'Effect of DifferencedAR1 hyperparameters on prior $\mathcal{R}(t)$ distribution. Lower autoreg and innovation_sd (left) produce a tighter envelope; higher values (right) allow wider drift. All start at $\mathcal{R}(0)$ = 1.65.' +#| label: fig-hyperparam-effect +#| fig-cap: 'Effect of DifferencedAR1 hyperparameters on prior $\mathcal{R}(t)$ distribution. Lower autoreg and innovation_sd (left) produce a tighter envelope; higher values (right) allow wider drift. All start at $\mathcal{R}(0)$ = 1.65.' + ( p9.ggplot(hp_df, p9.aes(x="day", y="rt", group="sample")) + p9.geom_line(alpha=0.1, size=0.5, color="steelblue") @@ -610,8 +660,9 @@ hp_df["config"] = pd.Categorical( ``` ```{python} -# | label: tbl-hyperparam-summary -# | tbl-cap: '$\mathcal{R}(t)$ at day 28 by DifferencedAR1 configuration.' +#| label: tbl-hyperparam-summary +#| tbl-cap: '$\mathcal{R}(t)$ at day 28 by DifferencedAR1 configuration.' + hp_summary_rows = [] for name in hyperparam_configs: rt = hyperparam_samples[name]["rt"][:, -1] @@ -631,14 +682,18 @@ Both hyperparameters matter, and they interact: - **Lower `autoreg`** means faster reversion of the rate of change toward zero, producing trajectories that settle more quickly. - **Lower `innovation_sd`** means smaller daily shocks to the rate of change, producing smoother trends. -There are no universally correct hyperparameter values. The right choice depends on the pathogen, the time horizon, and the density of the available data. Prior predictive checks are the tool for evaluating whether a given configuration produces trajectories that are scientifically plausible for your setting. +There are no universally correct hyperparameter values. +The right choice depends on the pathogen, the time horizon, and the density of the available data. +Prior predictive checks are the tool for evaluating whether a given configuration produces trajectories that are scientifically plausible for your setting. ### Infection Trajectories -$\mathcal{R}(t)$ drives the renewal equation, but the infection trajectory $I(t)$ is what observation processes actually transform into data. The transformation from $\mathcal{R}(t)$ to $I(t)$ is nonlinear (growth compounds exponentially), so the distribution of infection trajectories has heavier tails and stronger skew than the distribution of $\mathcal{R}(t)$. +$\mathcal{R}(t)$ drives the renewal equation, but the infection trajectory $I(t)$ is what observation processes actually transform into data. +The transformation from $\mathcal{R}(t)$ to $I(t)$ is nonlinear (growth compounds exponentially), so the distribution of infection trajectories has heavier tails and stronger skew than the distribution of $\mathcal{R}(t)$. ```{python} -# | label: infection-comparison-data +#| label: infection-comparison-data + inf_comparison = [] for label, s in process_labels.items(): inf = s["infections"] @@ -661,12 +716,11 @@ inf_comparison_df["process"] = pd.Categorical( ``` ```{python} -# | label: fig-infection-comparison -# | fig-cap: 'Prior predictive infection trajectories (log scale). The nonlinear transformation from $\mathcal{R}(t)$ to infections amplifies differences between temporal processes. The vertical spread represents orders-of-magnitude differences in infection levels.' +#| label: fig-infection-comparison +#| fig-cap: 'Prior predictive infection trajectories (log scale). The nonlinear transformation from $\mathcal{R}(t)$ to infections amplifies differences between temporal processes. The vertical spread represents orders-of-magnitude differences in infection levels.' + ( - p9.ggplot( - inf_comparison_df, p9.aes(x="day", y="infections", group="sample") - ) + p9.ggplot(inf_comparison_df, p9.aes(x="day", y="infections", group="sample")) + p9.geom_line(alpha=0.1, size=0.4, color="darkorange") + p9.facet_wrap("~ process", ncol=3) + p9.scale_y_log10() @@ -681,7 +735,8 @@ inf_comparison_df["process"] = pd.Categorical( ## Connecting to Observation Processes -The latent infection trajectory is not observed directly. To connect it to data, one or more **observation processes** transform $I(t)$ into expected observations. +The latent infection trajectory is not observed directly. +To connect it to data, one or more **observation processes** transform $I(t)$ into expected observations. The `PyrenewBuilder` handles the wiring: diff --git a/docs/tutorials/latent_subpopulation_infections.qmd b/docs/tutorials/latent_subpopulation_infections.qmd index 0839764b..d5757477 100644 --- a/docs/tutorials/latent_subpopulation_infections.qmd +++ b/docs/tutorials/latent_subpopulation_infections.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -18,8 +18,9 @@ jupyter: --- ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax.numpy as jnp import jax.random as random import numpy as np @@ -36,7 +37,8 @@ from _tutorial_theme import theme_tutorial ``` ```{python} -# | label: imports +#| label: imports + from pyrenew.latent import ( SubpopulationInfections, AR1, @@ -48,10 +50,10 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable ## Overview - `SubpopulationInfections` extends the renewal model to a population composed of $K$ subpopulations, each with its own latent infection trajectory. -As in the single-population case, infections evolve according to the renewal equation. For each subpopulation $k = 1, \dots, K$, we define +As in the single-population case, infections evolve according to the renewal equation. +For each subpopulation $k = 1, \dots, K$, we define $$ I_k(t) = \mathcal{R}_k(t) \sum_{\tau=1}^{G} I_k(t - \tau)\, w_\tau, @@ -89,7 +91,8 @@ Because deviations are additive on the log scale, they act multiplicatively on $ ### Aggregation -Each subpopulation produces its own infection trajectory $I_k(t)$. These are combined into an aggregate infection process using population fractions $p_k$, where $p_k$ is the fraction of the total population in subpopulation $k$, with +Each subpopulation produces its own infection trajectory $I_k(t)$. +These are combined into an aggregate infection process using population fractions $p_k$, where $p_k$ is the fraction of the total population in subpopulation $k$, with $$ \sum_{k=1}^{K} p_k = 1, \qquad p_k \ge 0. @@ -112,9 +115,10 @@ This model generalizes the single-population renewal model: When all $\delta_k(t) = 0$, the model reduces to the shared (single-population) case. -This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `log_rt_time_0`), and temporal processes. See [Latent Infections](latent_infections.md) for that background. +This tutorial assumes familiarity with the renewal equation, generation interval, initial conditions (`I0`, `log_rt_time_0`), and temporal processes. +See [Latent Infections](latent_infections.md) for that background. ---- +-------------------------------------------------------------------------------- ## Model Structure @@ -130,12 +134,10 @@ This tutorial assumes familiarity with the renewal equation, generation interval Two temporal processes define the evolution of $\log \mathcal{R}_k(t)$: -- **`baseline_rt_process`** - Temporal process for $\log \mathcal{R}_{\text{baseline}}(t)$. +- **`baseline_rt_process`** Temporal process for $\log \mathcal{R}_{\text{baseline}}(t)$. Produces a single trajectory shared across all subpopulations. -- **`subpop_rt_deviation_process`** - Temporal process for $\delta_k(t)$. +- **`subpop_rt_deviation_process`** Temporal process for $\delta_k(t)$. Produces $K$ trajectories (one per subpopulation), which are centered at each time point to satisfy the sum-to-zero constraint. Together, these define the full set of reproduction numbers: @@ -147,10 +149,10 @@ $$ \mathcal{R}_k(t) = \exp\big(\log \mathcal{R}_k(t)\big). $$ - ### Population structure -Population fractions $p_k$ are provided at **sample time**, not at model construction. This allows a single model specification to be reused across different jurisdictions or stratifications. +Population fractions $p_k$ are provided at **sample time**, not at model construction. +This allows a single model specification to be reused across different jurisdictions or stratifications. ### Random variables @@ -161,9 +163,9 @@ As in other PyRenew models, all inputs are specified as `RandomVariable`s: In practice, the temporal processes and initial conditions are typically given prior distributions, allowing the model to infer both shared and subpopulation-specific transmission dynamics. - ```{python} -# | label: model-setup +#| label: model-setup + gen_int_pmf = jnp.array([0.16, 0.32, 0.25, 0.14, 0.07, 0.04, 0.02]) gen_int_rv = DeterministicPMF("gen_int", gen_int_pmf) @@ -183,7 +185,8 @@ print(f"Log Rt at time 0: {np.exp(log_rt_time_0):.2f}") ``` ```{python} -# | label: instantiate +#| label: instantiate + model = SubpopulationInfections( name="SubpopulationInfections", gen_int_rv=gen_int_rv, @@ -197,13 +200,17 @@ model = SubpopulationInfections( ## The Sum-to-Zero Constraint -Without a constraint, the baseline and deviations are not identifiable: shifting the baseline up by some amount $c$ and all deviations down by $c$ produces the same subpopulation $\mathcal{R}_k(t)$ values. `SubpopulationInfections` enforces $\sum_k \delta_k(t) = 0$ at every time step by centering the raw deviation trajectories. +Without a constraint, the baseline and deviations are not identifiable: shifting the baseline up by some amount $c$ and all deviations down by $c$ produces the same subpopulation $\mathcal{R}_k(t)$ values. +`SubpopulationInfections` enforces $\sum_k \delta_k(t) = 0$ at every time step by centering the raw deviation trajectories. This ensures $\mathcal{R}_{\text{baseline}}(t)$ is the *unweighted* geometric mean of the subpopulation $\mathcal{R}_k(t)$ values, so the baseline represents the typical transmission level across subpopulations. -Note that this is the unweighted mean across subpopulations, not population-weighted by $p_k$. As a result, $\mathcal{R}_{\text{baseline}}(t)$ does **not** in general equal the jurisdiction-level reproduction number implied by the aggregate infection trajectory $I_{\text{aggregate}}(t)$. When subpopulations differ in size, a small subpopulation with a large $\mathcal{R}_k(t)$ contributes equally to the baseline but only marginally to the aggregate. +Note that this is the unweighted mean across subpopulations, not population-weighted by $p_k$. +As a result, $\mathcal{R}_{\text{baseline}}(t)$ does **not** in general equal the jurisdiction-level reproduction number implied by the aggregate infection trajectory $I_{\text{aggregate}}(t)$. +When subpopulations differ in size, a small subpopulation with a large $\mathcal{R}_k(t)$ contributes equally to the baseline but only marginally to the aggregate. ```{python} -# | label: verify-constraint +#| label: verify-constraint + with numpyro.handlers.seed(rng_seed=42): with numpyro.handlers.trace() as trace: model.sample( @@ -224,32 +231,38 @@ print( The baseline process governs the jurisdiction-level trend in $\mathcal{R}(t)$. The same temporal process options apply as in `PopulationInfections` (see [Temporal Process Choice](latent_infections.md#temporal-process-choice)): AR(1) for mean reversion, DifferencedAR1 for persistent trends with stabilizing rate of change, RandomWalk for unconstrained drift. -We use DifferencedAR1 with small `innovation_sd` for the baseline throughout this tutorial. The prior predictive for baseline $\mathcal{R}(t)$ behaves much like the corresponding `PopulationInfections` example, so we do not repeat that full comparison here. The focus of this tutorial is the deviation process and how it changes the spread of subpopulation trajectories around the baseline. +We use DifferencedAR1 with small `innovation_sd` for the baseline throughout this tutorial. +The prior predictive for baseline $\mathcal{R}(t)$ behaves much like the corresponding `PopulationInfections` example, so we do not repeat that full comparison here. +The focus of this tutorial is the deviation process and how it changes the spread of subpopulation trajectories around the baseline. The same high-level decision rules apply here: -| If you believe... | Consider | -|-------------------|----------| -| Jurisdiction-wide $\mathcal{R}(t)$ should stay near a long-run level | `AR1` for the baseline | -| Jurisdiction-wide $\mathcal{R}(t)$ may drift over the modeled horizon | `DifferencedAR1` for the baseline | -| Local differences should fade back toward the baseline | `AR1` for deviations | -| Local differences can persist or accumulate | `RandomWalk` for deviations | + | If you believe... | Consider | + | --------------------------------------------------------------------- | --------------------------------- | + | Jurisdiction-wide $\mathcal{R}(t)$ should stay near a long-run level | `AR1` for the baseline | + | Jurisdiction-wide $\mathcal{R}(t)$ may drift over the modeled horizon | `DifferencedAR1` for the baseline | + | Local differences should fade back toward the baseline | `AR1` for deviations | + | Local differences can persist or accumulate | `RandomWalk` for deviations | Two parameters matter most when tuning these processes: -- **`autoreg`**: Controls how strongly AR(1)-type processes revert. Values near 1 imply slow reversion; smaller values imply faster pullback. -- **`innovation_sd`**: Controls day-to-day volatility. Larger values produce wider prior spreads and more abrupt movement. +- **`autoreg`**: Controls how strongly AR(1)-type processes revert. + Values near 1 imply slow reversion; smaller values imply faster pullback. +- **`innovation_sd`**: Controls day-to-day volatility. + Larger values produce wider prior spreads and more abrupt movement. ## Choosing the Deviation Temporal Process The deviation process controls how subpopulation $\mathcal{R}_k(t)$ values move relative to the baseline $\mathcal{R}_{\text{baseline}}(t)$, not the overall trend itself. -How $\delta_k(t)$ behaves at $t = 0$ depends on the process. `AR1` draws its initial value from the stationary distribution, so the prior spread of $\delta_k(0)$ already matches the stationary standard deviation and stays at that width throughout. `RandomWalk` starts at exactly $\delta_k(0) = 0$ and its spread grows over time. +How $\delta_k(t)$ behaves at $t = 0$ depends on the process. +`AR1` draws its initial value from the stationary distribution, so the prior spread of $\delta_k(0)$ already matches the stationary standard deviation and stays at that width throughout. +`RandomWalk` starts at exactly $\delta_k(0) = 0$ and its spread grows over time. -The key question: **are local differences transient or persistent?** -This determines whether subpopulations quickly return to the baseline or can diverge and remain different over time. +The key question: **are local differences transient or persistent?** This determines whether subpopulations quickly return to the baseline or can diverge and remain different over time. ```{python} -# | label: helpers +#| label: helpers + def sample_hierarchical(baseline_process, deviation_process, label): """Draw prior predictive samples from a SubpopulationInfections model.""" m = SubpopulationInfections( @@ -271,15 +284,15 @@ def sample_hierarchical(baseline_process, deviation_process, label): samples = Predictive(sampler, num_samples=n_samples)(random.PRNGKey(42)) return { - "rt_baseline": np.array( - samples["SubpopulationInfections::rt_baseline"] - )[:, n_init:, 0], + "rt_baseline": np.array(samples["SubpopulationInfections::rt_baseline"])[ + :, n_init:, 0 + ], "rt_subpop": np.array(samples["SubpopulationInfections::rt_subpop"])[ :, n_init:, : ], - "deviations": np.array( - samples["SubpopulationInfections::subpop_deviations"] - )[:, n_init:, :], + "deviations": np.array(samples["SubpopulationInfections::subpop_deviations"])[ + :, n_init:, : + ], "infections": np.array( samples["SubpopulationInfections::infections_aggregate"] )[:, n_init:], @@ -289,15 +302,19 @@ def sample_hierarchical(baseline_process, deviation_process, label): baseline_process = DifferencedAR1(autoreg=0.5, innovation_sd=0.01) ``` - ### AR(1) Deviations -AR(1) deviations have **bounded variance**. `AR1` draws its initial value from the stationary distribution of the process, so $\delta_k(0)$ is already dispersed at the stationary standard deviation $\sigma / \sqrt{1 - \phi^2}$ and the envelope stays at that width. With `autoreg = 0.8` and `innovation_sd = 0.05`, the stationary standard deviation is approximately $0.083$ on the log scale. +AR(1) deviations have **bounded variance**. +`AR1` draws its initial value from the stationary distribution of the process, so $\delta_k(0)$ is already dispersed at the stationary standard deviation $\sigma / \sqrt{1 - \phi^2}$ and the envelope stays at that width. +With `autoreg = 0.8` and `innovation_sd = 0.05`, the stationary standard deviation is approximately $0.083$ on the log scale. -The `autoreg` coefficient still matters: if a subpopulation drifts away from zero by chance, $\phi$ controls how quickly it is pulled back. Values near 1 produce slow pullback (local differences linger), values near 0 produce fast pullback (subpopulations snap back to the baseline). What this looks like in a prior predictive plot is a **constant-width band** rather than a fanning-out cloud. +The `autoreg` coefficient still matters: if a subpopulation drifts away from zero by chance, $\phi$ controls how quickly it is pulled back. +Values near 1 produce slow pullback (local differences linger), values near 0 produce fast pullback (subpopulations snap back to the baseline). +What this looks like in a prior predictive plot is a **constant-width band** rather than a fanning-out cloud. ```{python} -# | label: ar1-deviations-sample +#| label: ar1-deviations-sample + ar1_dev_samples = sample_hierarchical( baseline_process, AR1(autoreg=0.8, innovation_sd=0.05), @@ -306,7 +323,8 @@ ar1_dev_samples = sample_hierarchical( ``` ```{python} -# | label: ar1-deviations-data +#| label: ar1-deviations-data + def deviation_df(samples, label): """Build a long-format dataframe of deviation trajectories.""" devs = samples["deviations"] @@ -350,8 +368,9 @@ ar1_dev_summary = deviation_summary_df(ar1_dev_samples) ``` ```{python} -# | label: fig-ar1-deviations -# | fig-cap: 'AR(1) deviation trajectories (50 samples, all 6 subpopulations). Because `AR1` draws its initial value from the stationary distribution, the envelope is already at its stationary width at day 0 and stays bounded there. Compare the y-axis scale to fig-rw-deviations below.' +#| label: fig-ar1-deviations +#| fig-cap: "AR(1) deviation trajectories (50 samples, all 6 subpopulations). Because `AR1` draws its initial value from the stationary distribution, the envelope is already at its stationary width at day 0 and stays bounded there. Compare the y-axis scale to fig-rw-deviations below." + ( p9.ggplot() + p9.geom_line( @@ -387,10 +406,14 @@ ar1_dev_summary = deviation_summary_df(ar1_dev_samples) ### RandomWalk Deviations -RandomWalk deviations have **unbounded variance**. The spread of $\delta_k(t)$ grows linearly with time as $\sigma^2 t$, so over a 28-day horizon with `innovation_sd = 0.05` the standard deviation reaches roughly $0.265$ on the log scale — about three times the AR(1) stationary value above. There is no pullback toward zero: subpopulations that drift away from the baseline stay drifted, and differences that emerge early persist and can grow. The visual signature is a **fanning-out cloud** rather than a constant-width band. +RandomWalk deviations have **unbounded variance**. +The spread of $\delta_k(t)$ grows linearly with time as $\sigma^2 t$, so over a 28-day horizon with `innovation_sd = 0.05` the standard deviation reaches roughly $0.265$ on the log scale --- about three times the AR(1) stationary value above. +There is no pullback toward zero: subpopulations that drift away from the baseline stay drifted, and differences that emerge early persist and can grow. +The visual signature is a **fanning-out cloud** rather than a constant-width band. ```{python} -# | label: rw-deviations-sample +#| label: rw-deviations-sample + rw_dev_samples = sample_hierarchical( baseline_process, RandomWalk(innovation_sd=0.05), @@ -399,14 +422,16 @@ rw_dev_samples = sample_hierarchical( ``` ```{python} -# | label: rw-deviations-data +#| label: rw-deviations-data + rw_dev_df = deviation_df(rw_dev_samples, "RandomWalk deviations") rw_dev_summary = deviation_summary_df(rw_dev_samples) ``` ```{python} -# | label: fig-rw-deviations -# | fig-cap: 'RandomWalk deviation trajectories (50 samples, 6 subpopulations). The envelope fans out continuously over the 28-day horizon and reaches roughly three times the AR(1) stationary width. The y-axis matches fig-ar1-deviations above for direct comparison.' +#| label: fig-rw-deviations +#| fig-cap: "RandomWalk deviation trajectories (50 samples, 6 subpopulations). The envelope fans out continuously over the 28-day horizon and reaches roughly three times the AR(1) stationary width. The y-axis matches fig-ar1-deviations above for direct comparison." + ( p9.ggplot() + p9.geom_line( @@ -443,8 +468,9 @@ rw_dev_summary = deviation_summary_df(rw_dev_samples) ### Comparing Deviation Processes ```{python} -# | label: tbl-deviation-summary -# | tbl-cap: Deviation spread at day 28 (across all subpopulations and samples). +#| label: tbl-deviation-summary +#| tbl-cap: Deviation spread at day 28 (across all subpopulations and samples). + deviation_summary_rows = [] for label, s in [("AR(1)", ar1_dev_samples), ("RandomWalk", rw_dev_samples)]: devs = np.abs(s["deviations"][:, -1, :]).flatten() @@ -459,17 +485,25 @@ for label, s in [("AR(1)", ar1_dev_samples), ("RandomWalk", rw_dev_samples)]: pd.DataFrame(deviation_summary_rows) ``` -AR(1) deviations stay close to zero because mean reversion continuously pulls them back. RandomWalk deviations accumulate over time. As in the PopulationInfections tutorial, this is prior-modeling guidance rather than a uniquely determined epidemiologic rule. The choice depends on the epidemiological setting: +AR(1) deviations stay close to zero because mean reversion continuously pulls them back. +RandomWalk deviations accumulate over time. +As in the PopulationInfections tutorial, this is prior-modeling guidance rather than a uniquely determined epidemiologic rule. +The choice depends on the epidemiological setting: -- **AR(1) deviations** when subpopulations are expected to track the jurisdiction average. Local outbreaks or lulls are temporary. This is typical for geographically close subpopulations (e.g., counties within a metropolitan area) where mobility mixes transmission. -- **RandomWalk deviations** when local differences can persist. This may be appropriate for subpopulations with distinct contact patterns, demographics, or intervention histories (e.g., urban vs. rural areas). +- **AR(1) deviations** when subpopulations are expected to track the jurisdiction average. + Local outbreaks or lulls are temporary. + This is typical for geographically close subpopulations (e.g., counties within a metropolitan area) where mobility mixes transmission. +- **RandomWalk deviations** when local differences can persist. + This may be appropriate for subpopulations with distinct contact patterns, demographics, or intervention histories (e.g., urban vs. rural areas). ## Baseline and Deviation Pairs -The baseline and deviation processes interact to determine the full prior over subpopulation $\mathcal{R}_k(t)$ trajectories. We compare two configurations: DifferencedAR1 baseline with AR(1) deviations, and DifferencedAR1 baseline with RandomWalk deviations. +The baseline and deviation processes interact to determine the full prior over subpopulation $\mathcal{R}_k(t)$ trajectories. +We compare two configurations: DifferencedAR1 baseline with AR(1) deviations, and DifferencedAR1 baseline with RandomWalk deviations. ```{python} -# | label: pairs-rt-data +#| label: pairs-rt-data + def subpop_rt_df(samples, label): """Build a long-format dataframe of subpopulation Rt trajectories.""" rt = samples["rt_subpop"] @@ -508,8 +542,9 @@ pairs_df = pd.concat( ``` ```{python} -# | label: fig-pairs-baseline -# | fig-cap: 'Baseline $\mathcal{R}(t)$ trajectories are identical for both configurations (same process, same seed). The difference is in how subpopulations spread around this baseline.' +#| label: fig-pairs-baseline +#| fig-cap: 'Baseline $\mathcal{R}(t)$ trajectories are identical for both configurations (same process, same seed). The difference is in how subpopulations spread around this baseline.' + baseline_df = pairs_df[pairs_df["subpop"] == "baseline"] ( p9.ggplot(baseline_df, p9.aes(x="day", y="rt", group="sample")) @@ -526,7 +561,8 @@ baseline_df = pairs_df[pairs_df["subpop"] == "baseline"] ``` ```{python} -# | label: pairs-subpop-summary-data +#| label: pairs-subpop-summary-data + def baseline_vs_subpop_summary(samples, subpop_idx, label): """Summarize baseline and one subpopulation's Rt across all prior draws.""" rt_base = samples["rt_baseline"] @@ -580,8 +616,9 @@ subpop_summary = pairs_summary_df[pairs_summary_df["series"] == "subpop 0"] ``` ```{python} -# | label: fig-pairs-subpop -# | fig-cap: 'Prior predictive bands for subpopulation 0''s $\mathcal{R}(t)$ (blue) overlaid on the shared baseline $\mathcal{R}_{\text{baseline}}(t)$ (black), under two deviation process choices. Each panel shows 50 subpopulation 0 trajectories (light blue lines), the 5–95% prior interval for subpopulation 0 (blue ribbon) and for the baseline (grey ribbon), and their medians (solid lines). The y-axis is log-scaled so that additive log-deviations appear as constant multiplicative widths. With AR(1) deviations (left), the subpopulation band has nearly the same width as the baseline band at every day, because the deviation variance is stationary. With RandomWalk deviations (right), the subpopulation band widens relative to the baseline band as the horizon grows, because deviation variance accumulates.' +#| label: fig-pairs-subpop +#| fig-cap: 'Prior predictive bands for subpopulation 0''s $\mathcal{R}(t)$ (blue) overlaid on the shared baseline $\mathcal{R}_{\text{baseline}}(t)$ (black), under two deviation process choices. Each panel shows 50 subpopulation 0 trajectories (light blue lines), the 5–95% prior interval for subpopulation 0 (blue ribbon) and for the baseline (grey ribbon), and their medians (solid lines). The y-axis is log-scaled so that additive log-deviations appear as constant multiplicative widths. With AR(1) deviations (left), the subpopulation band has nearly the same width as the baseline band at every day, because the deviation variance is stationary. With RandomWalk deviations (right), the subpopulation band widens relative to the baseline band as the horizon grows, because deviation variance accumulates.' + ( p9.ggplot() + p9.geom_ribbon( @@ -628,8 +665,9 @@ subpop_summary = pairs_summary_df[pairs_summary_df["series"] == "subpop 0"] ``` ```{python} -# | label: fig-pairs-all-subpops -# | fig-cap: 'All 6 subpopulation $\mathcal{R}(t)$ trajectories from a single prior draw, under both configurations. AR(1) deviations (left) keep subpopulations tightly bundled; RandomWalk deviations (right) allow them to spread apart.' +#| label: fig-pairs-all-subpops +#| fig-cap: 'All 6 subpopulation $\mathcal{R}(t)$ trajectories from a single prior draw, under both configurations. AR(1) deviations (left) keep subpopulations tightly bundled; RandomWalk deviations (right) allow them to spread apart.' + single_draw_data = [] for label, s in [ ("DifferencedAR1 +\nAR(1) deviations", ar1_dev_samples), @@ -692,7 +730,8 @@ single_draw_df["config"] = pd.Categorical( ## Connecting to Observation Processes -The latent infection trajectory is not observed directly. Each observation process selects a subset of subpopulations (via `subpop_indices`) and applies its own ascertainment rate and delay distribution. +The latent infection trajectory is not observed directly. +Each observation process selects a subset of subpopulations (via `subpop_indices`) and applies its own ascertainment rate and delay distribution. The `PyrenewBuilder` handles the wiring: diff --git a/docs/tutorials/observation_processes_counts.qmd b/docs/tutorials/observation_processes_counts.qmd index bb69e895..b27cf46f 100644 --- a/docs/tutorials/observation_processes_counts.qmd +++ b/docs/tutorials/observation_processes_counts.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -20,8 +20,9 @@ jupyter: This tutorial demonstrates how to use the `PopulationCounts` observation process to model count data such as hospital admissions, emergency department visits, or deaths. ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax.numpy as jnp import numpy as np import numpyro @@ -79,28 +80,31 @@ This equation is a discrete convolution: observations at time $t$ arise from inf The observation equation defines the expected number of observed events, but the actual data are stochastic. -Let $Y(t)$ denote the observed number of events at time $t$. Observations are modeled as draws from a count distribution with central value (typically but not necessary mean) $\mu(t)$: +Let $Y(t)$ denote the observed number of events at time $t$. +Observations are modeled as draws from a count distribution with central value (typically but not necessary mean) $\mu(t)$: $$ Y(t) \sim \text{Distribution}(\mu(t), \theta). $$ One possible choice is the Poisson distribution, which assumes the variance equals the mean. -In practice, epidemiological count data are often overdispersed relative to the Poisson. Negative binomial distributions are a common choice for modeling these overdispersed counts. - +In practice, epidemiological count data are often overdispersed relative to the Poisson. +Negative binomial distributions are a common choice for modeling these overdispersed counts. This yields a two-layer model: - A **mechanistic layer**, where the delay convolution determines the predicted number of observations $\mu(t)$ from latent infections $I(t)$. - A **stochastic observation layer**, where observed counts $Y(t)$ vary around $\mu(t)$ according to a specified distribution. -**Note on terminology:** In real-world inference, incident infections $I(t)$ are typically *latent* (unobserved) and must be inferred from observed data. In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce observed counts through convolution and sampling. +**Note on terminology:** In real-world inference, incident infections $I(t)$ are typically *latent* (unobserved) and must be inferred from observed data. +In this tutorial, we simulate the observation process by specifying infections directly and showing how they produce observed counts through convolution and sampling. ## Hospital admissions example For hospital admissions data, we construct a `PopulationCounts` observation process. -The delay is the key mechanism: infections from $d$ days ago ($I(t-d)$) contribute to today's predicted hospital admissions ($\mu(t)$), weighted by the probability $\pi_d$ that an infection leads to hospitalization after exactly $d$ days. The convolution sums these contributions across all past days. +The delay is the key mechanism: infections from $d$ days ago ($I(t-d)$) contribute to today's predicted hospital admissions ($\mu(t)$), weighted by the probability $\pi_d$ that an infection leads to hospitalization after exactly $d$ days. +The convolution sums these contributions across all past days. Observed hospital admissions are then generated by sampling from a negative binomial distribution: @@ -125,10 +129,12 @@ In this example, we use fixed parameter values for illustration; in practice, th The delay distribution specifies the probability that an infected person is hospitalized $d$ days after infection, conditional on the infection leading to a hospitalization. For example, if `hosp_delay_pmf[5] = 0.2`, then 20% of infections that result in hospitalization will appear as hospital admissions 5 days after infection. -We load a delay distribution from PyRenew's example datasets, compute summary statistics, and plot it. The distribution peaks around day 8-9 post-infection. +We load a delay distribution from PyRenew's example datasets, compute summary statistics, and plot it. +The distribution peaks around day 8-9 post-infection. ```{python} -# | label: delay-distribution +#| label: delay-distribution + inf_hosp_int = datasets.load_example_infection_admission_interval() hosp_delay_pmf = jnp.array(inf_hosp_int["probability_mass"].to_numpy()) delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) @@ -138,26 +144,19 @@ days = np.arange(len(hosp_delay_pmf)) mean_delay = float(np.sum(days * hosp_delay_pmf)) mode_delay = int(np.argmax(hosp_delay_pmf)) sd = float(np.sqrt(np.sum(days**2 * hosp_delay_pmf) - mean_delay**2)) -print( - f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}" -) +print(f"mode delay: {mode_delay}, mean delay: {mean_delay:.1f}, sd: {sd:.1f}") ``` ```{python} -# | label: plot-delay-distribution -delay_df = pd.DataFrame( - {"days": days, "probability": np.array(hosp_delay_pmf)} -) +#| label: plot-delay-distribution + +delay_df = pd.DataFrame({"days": days, "probability": np.array(hosp_delay_pmf)}) plot_delay = ( p9.ggplot(delay_df, p9.aes(x="days", y="probability")) + p9.geom_col(fill="steelblue", alpha=0.7, color="black") - + p9.geom_vline( - xintercept=mode_delay, color="purple", linetype="solid", size=1 - ) - + p9.geom_vline( - xintercept=mean_delay, color="red", linetype="dashed", size=1 - ) + + p9.geom_vline(xintercept=mode_delay, color="purple", linetype="solid", size=1) + + p9.geom_vline(xintercept=mean_delay, color="red", linetype="dashed", size=1) + p9.labs( x="Days from infection to hospitalization", y="Probability", @@ -193,16 +192,17 @@ A `PopulationCounts` object takes the following arguments: - **`delay_distribution_rv`**: delay distribution from infection to observation (PMF) - **`noise`**: noise model (`PoissonNoise()` or `NegativeBinomialNoise(concentration_rv)`) -Observation processes are components in multi-signal models, where each signal must have a unique name. This name prefixes all numpyro sample sites (e.g., `"hospital"` creates sites `"hospital_obs"` and `"hospital_predicted"`), ensuring distinct identifiers in the inference trace. +Observation processes are components in multi-signal models, where each signal must have a unique name. +This name prefixes all numpyro sample sites (e.g., `"hospital"` creates sites `"hospital_obs"` and `"hospital_predicted"`), ensuring distinct identifiers in the inference trace. For hospital admissions, the ascertainment rate is specifically called the infection-hospitalization rate (IHR). -In this example, the percentage of infections which lead to hospitalization is treated as a fixed value, -which will allow us to see how different values affect the model. +In this example, the percentage of infections which lead to hospitalization is treated as a fixed value, which will allow us to see how different values affect the model. The concentration parameter for the negative binomial noise model is also fixed. In practice, both of these parameters would be given a somewhat informative prior and then inferred. ```{python} -# | label: create-counts-process +#| label: create-counts-process + # Infection-hospitalization ratio (1% of infections lead to hospitalization) ihr_rv = DeterministicVariable("ihr", 0.01) @@ -222,10 +222,14 @@ hosp_process = PopulationCounts( The observation process convolves infections with a delay distribution, maintaining alignment between input and output: day $t$ in the output corresponds to day $t$ in the input. -Hospital admissions depend on infections from prior days. A delay PMF of length $L$ covers delays 0 to $L-1$, requiring $L-1$ days of prior infection history. The method `lookback_days()` returns $L-1$; the first valid observation day is at index `lookback_days()`. Earlier days are marked invalid. +Hospital admissions depend on infections from prior days. +A delay PMF of length $L$ covers delays 0 to $L-1$, requiring $L-1$ days of prior infection history. +The method `lookback_days()` returns $L-1$; the first valid observation day is at index `lookback_days()`. +Earlier days are marked invalid. ```{python} -# | label: helper-function +#| label: helper-function + print(f"Required lookback: {hosp_process.lookback_days()} days") @@ -238,9 +242,9 @@ def first_valid_observation_day(obs_process) -> int: To demonstrate how a `PopulationCounts` observation process works, we examine how infections occurring on a single day result in observed hospital admissions. - ```{python} -# | label: simulate-spike +#| label: simulate-spike + n_days = 100 day_one = first_valid_observation_day(hosp_process) @@ -251,8 +255,10 @@ infections = infections.at[infection_spike_day].set(2000) ``` We plot the infections starting from day_one (the first valid observation day, after the lookback period). + ```{python} -# | label: plot-infections +#| label: plot-infections + # Plot relative to first valid observation day n_plot_days = n_days - day_one rel_spike_day = infection_spike_day - day_one @@ -297,9 +303,9 @@ Because all infections occur on a single day, this example shows how a single pu First, we compute the predicted admissions from the convolution alone, without observation noise. This gives the predicted number of observations $\mu(t)$. - ```{python} -# | label: predicted-no-noise +#| label: predicted-no-noise + # Compute predicted admissions (convolution only, no observation noise) from pyrenew.convolve import compute_delay_ascertained_incidence @@ -314,10 +320,12 @@ predicted_admissions, offset = compute_delay_ascertained_incidence( ) ``` -*Note:* in the above implementation, the ascertainment rate is applied by scaling infections before convolution. This is equivalent to applying $\alpha$ after the convolution in the observation equation. +*Note:* in the above implementation, the ascertainment rate is applied by scaling infections before convolution. +This is equivalent to applying $\alpha$ after the convolution in the observation equation. ```{python} -# | label: plot-predicted-no-noise +#| label: plot-predicted-no-noise + # Relative peak day for plotting peak_day = rel_spike_day + mode_delay @@ -382,7 +390,8 @@ The negative binomial distribution adds stochastic variation around $\mu(t)$, co Sampling multiple times from the same infections shows the range of possible observations: ```{python} -# | label: sample-realizations +#| label: sample-realizations + # Sample 50 realizations of hospital admissions from the same infection spike n_samples = 50 samples_list = [] @@ -415,7 +424,8 @@ for i, val in enumerate(predicted_admissions[day_one:]): ``` ```{python} -# | label: plot-realizations +#| label: plot-realizations + samples_df = pd.DataFrame(samples_list) sampled_df = samples_df[samples_df["type"] == "sampled"] predicted_noise_df = samples_df[samples_df["type"] == "predicted"] @@ -464,7 +474,8 @@ plot_50_samples ``` ```{python} -# | label: timeline-stats +#| label: timeline-stats + # Print timeline statistics print("Timeline Analysis:") print( @@ -482,7 +493,8 @@ The ascertainment rate (here, the infection-hospitalization rate or IHR) directl We compare two contrasting IHR values: **0.5%** and **2.5%**. ```{python} -# | label: compare-ihr +#| label: compare-ihr + # Two contrasting IHR values ihr_values = [0.005, 0.025] peak_value = 3000 # Peak infections @@ -509,9 +521,9 @@ for ihr_val in ihr_values: ) ``` - ```{python} -# | label: plot-ihr-comparisons +#| label: plot-ihr-comparisons + results_df = pd.DataFrame(results_list) plot_ihr = ( @@ -538,14 +550,15 @@ The concentration parameter $\phi$ controls overdispersion: We compare three concentration values spanning two orders of magnitude: -- **$\phi$ = 1**: high overdispersion (noisy) -- **$\phi$ = 10**: moderate overdispersion -- **$\phi$ = 100**: nearly Poisson (minimal noise) +- **$\phi$= 1**: high overdispersion (noisy) +- **$\phi$= 10**: moderate overdispersion +- **$\phi$= 100**: nearly Poisson (minimal noise) We hold daily infections constant over time so that any variation in the observed counts comes entirely from the observation model. ```{python} -# | label: concentration-comparisons +#| label: concentration-comparisons + # Use constant infections peak_value = 2000 infections_constant = peak_value * jnp.ones(n_days) @@ -585,7 +598,8 @@ for conc_val in concentration_values: ``` ```{python} -# | label: plot-concentration-comparisons +#| label: plot-concentration-comparisons + conc_df = pd.DataFrame(conc_results) # Convert to ordered categorical @@ -614,7 +628,8 @@ plot_concentration To use Poisson noise instead of negative binomial, change the noise model: ```{python} -# | label: poisson-noise +#| label: poisson-noise + hosp_process_poisson = PopulationCounts( name="hospital", ascertainment_rate_rv=ihr_rv, @@ -633,12 +648,14 @@ print( ) ``` -We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. +We can visualize the Poisson noise model using the same constant infection scenario as the concentration comparison above. +Since Poisson assumes variance equals the mean, it produces less variability than the negative binomial with low concentration values. To see the reduction in noise, it is necessary to keep the y-axis on the same scale as in the previous plot. ```{python} -# | label: poisson-realizations +#| label: poisson-realizations + # Sample multiple realizations with Poisson noise n_replicates_poisson = 10 diff --git a/docs/tutorials/observation_processes_measurements.qmd b/docs/tutorials/observation_processes_measurements.qmd index 3419315a..e6649f47 100644 --- a/docs/tutorials/observation_processes_measurements.qmd +++ b/docs/tutorials/observation_processes_measurements.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -17,11 +17,13 @@ jupyter: name: python3 --- -This tutorial demonstrates how to use the `MeasurementObservation` observation process to model continuous measurement data. We first explain the general framework, then illustrate with a wastewater viral concentration example. +This tutorial demonstrates how to use the `MeasurementObservation` observation process to model continuous measurement data. +We first explain the general framework, then illustrate with a wastewater viral concentration example. ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax import jax.numpy as jnp import numpy as np @@ -42,7 +44,8 @@ from pyrenew.deterministic import DeterministicVariable, DeterministicPMF ## The MeasurementObservation Class -The `MeasurementObservation` class models continuous signals derived from infections. Unlike count observations (hospital admissions, deaths), measurements are continuous values that may span orders of magnitude or even be negative (e.g., log-transformed data). +The `MeasurementObservation` class models continuous signals derived from infections. +Unlike count observations (hospital admissions, deaths), measurements are continuous values that may span orders of magnitude or even be negative (e.g., log-transformed data). **Examples of measurement data:** @@ -58,7 +61,8 @@ All measurement observation processes follow the same two-layer structure as the 1. A deterministic transformation defining the expected measurement value 2. A stochastic observation model -Let $\mu(t)$ denote the expected measurement at time $t$. Observations are modeled as +Let $\mu(t)$ denote the expected measurement at time $t$. +Observations are modeled as $$ Y(t) \sim \text{Noise}(f(\mu(t))), @@ -66,7 +70,8 @@ $$ where $f(\cdot)$ is a transformation (often the identity or log function), and the noise model adds stochastic variation around this transformed prediction. -Subclasses implement `_predicted_obs()` to compute $\mu(t)$ from infections. PyRenew provides the noise model. +Subclasses implement `_predicted_obs()` to compute $\mu(t)$ from infections. +PyRenew provides the noise model. The `MeasurementObservation` base class provides: @@ -79,13 +84,13 @@ The `MeasurementObservation` base class provides: The core convolution structure is shared with count observations, but key aspects differ: -| Aspect | CountObservation | MeasurementObservation | -|--------|--------|--------------| -| Output type | Discrete counts | Continuous values | -| Output space | Linear (expected counts) | Often log-transformed | -| Noise model | Poisson or Negative Binomial | Normal (often on log scale) | -| Scaling | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific | -| Subpop structure | Optional (`CountsBySubpop`) | Inherent (sensor/site effects) | + | Aspect | CountObservation | MeasurementObservation | + | ---------------- | ------------------------------------- | ------------------------------ | + | Output type | Discrete counts | Continuous values | + | Output space | Linear (expected counts) | Often log-transformed | + | Noise model | Poisson or Negative Binomial | Normal (often on log scale) | + | Scaling | Ascertainment rate $\alpha \in [0,1]$ | Domain-specific | + | Subpop structure | Optional (`CountsBySubpop`) | Inherent (sensor/site effects) | ### The noise model @@ -100,10 +105,12 @@ Measurement data typically exhibits **sensor-level variability**: different inst observed ~ Normal(f(\mu(t)) + sensor_mode[sensor], sensor_sd[sensor]) ``` -The sensor-level RVs must implement `sample(n_groups=...)`. Use `VectorizedVariable` to wrap simple distributions: +The sensor-level RVs must implement `sample(n_groups=...)`. +Use `VectorizedVariable` to wrap simple distributions: ```{python} -# | label: noise-model-general +#| label: noise-model-general + # Sensor modes: zero-centered, allowing positive or negative bias sensor_mode_rv = VectorizedVariable( "vec_sensor_mode", @@ -129,11 +136,11 @@ noise = HierarchicalNormalNoise( Measurement observations use three index arrays to map observations to their context: -| Index array | Purpose | -|-------------|---------| -| `times` | Day index for each observation | -| `subpop_indices` | Which infection trajectory (subpopulation) generated each observation | -| `sensor_indices` | Which sensor made each observation (determines noise parameters) | + | Index array | Purpose | + | ---------------- | --------------------------------------------------------------------- | + | `times` | Day index for each observation | + | `subpop_indices` | Which infection trajectory (subpopulation) generated each observation | + | `sensor_indices` | Which sensor made each observation (determines noise parameters) | This flexible indexing supports: @@ -149,7 +156,9 @@ To create a measurement process for your domain, subclass `MeasurementObservatio 2. **`validate()`**: Check parameter validity 3. **`lookback_days()`**: Return the temporal PMF length -The `MeasurementObservation` base class requires a `name` parameter. Observation processes are components in multi-signal models, where each signal must have a unique, meaningful name (e.g., `"wastewater"`, `"air_quality"`). This name prefixes all numpyro sample sites, ensuring distinct identifiers in the inference trace. +The `MeasurementObservation` base class requires a `name` parameter. +Observation processes are components in multi-signal models, where each signal must have a unique, meaningful name (e.g., `"wastewater"`, `"air_quality"`). +This name prefixes all numpyro sample sites, ensuring distinct identifiers in the inference trace. ```python class MyMeasurement(MeasurementObservation): @@ -171,11 +180,9 @@ class MyMeasurement(MeasurementObservation): return len(self.temporal_pmf_rv()) - 1 ``` - ## Measurement Example: Wastewater -To illustrate the framework, we specify a wastewater viral concentration observation process, -based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-multisignal) family of models. +To illustrate the framework, we specify a wastewater viral concentration observation process, based on the [PyRenew-HEW](https://github.com/cdcgov/pyrenew-multisignal) family of models. **The wastewater signal** @@ -227,7 +234,8 @@ The key difference from count data is that measurements are modeled on a continu ### Implementing the Wastewater class ```{python} -# | label: wastewater-class +#| label: wastewater-class + from jax.typing import ArrayLike from pyrenew.metaclass import RandomVariable from pyrenew.observation.noise import MeasurementNoise @@ -265,9 +273,7 @@ class Wastewater(MeasurementObservation): noise : MeasurementNoise Noise model (e.g., HierarchicalNormalNoise). """ - super().__init__( - name=name, temporal_pmf_rv=shedding_kinetics_rv, noise=noise - ) + super().__init__(name=name, temporal_pmf_rv=shedding_kinetics_rv, noise=noise) self.log10_genome_per_infection_rv = log10_genome_per_infection_rv self.ml_per_person_per_day = ml_per_person_per_day @@ -299,15 +305,11 @@ class Wastewater(MeasurementObservation): return convolved # Apply to all subpops (infections shape: n_days x n_subpops) - shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)( - infections - ) + shedding_signal = jax.vmap(convolve_site, in_axes=1, out_axes=1)(infections) # Convert to concentration: genomes per mL genome_copies = 10**log10_genome - concentration = ( - shedding_signal * genome_copies / self.ml_per_person_per_day - ) + concentration = shedding_signal * genome_copies / self.ml_per_person_per_day # Return log-concentration (what we model) return jnp.log(concentration) @@ -320,11 +322,10 @@ class Wastewater(MeasurementObservation): The shedding PMF describes what fraction of total viral shedding occurs on each day after infection: ```{python} -# | label: shedding-pmf +#| label: shedding-pmf + # Peak shedding ~3 days after infection, continues for ~10 days -shedding_pmf = jnp.array( - [0.0, 0.05, 0.15, 0.25, 0.20, 0.15, 0.10, 0.05, 0.03, 0.02] -) +shedding_pmf = jnp.array([0.0, 0.05, 0.15, 0.25, 0.20, 0.15, 0.10, 0.05, 0.03, 0.02]) print(f"PMF sums to: {shedding_pmf.sum():.2f}") shedding_rv = DeterministicPMF("viral_shedding", shedding_pmf) @@ -337,11 +338,10 @@ print(f"Mode: {mode_shedding_day} days, Mean: {mean_shedding_day:.1f} days") ``` ```{python} -# | label: plot-shedding +#| label: plot-shedding + # Visualize the shedding distribution -shedding_df = pd.DataFrame( - {"days": days, "probability": np.array(shedding_pmf)} -) +shedding_df = pd.DataFrame({"days": days, "probability": np.array(shedding_pmf)}) ( p9.ggplot(shedding_df, p9.aes(x="days", y="probability")) @@ -380,7 +380,8 @@ shedding_df = pd.DataFrame( **Genome copies and wastewater volume** ```{python} -# | label: scaling-params +#| label: scaling-params + # Log10 genome copies shed per infection (typical range: 8-10) log10_genome_rv = DeterministicVariable("log10_genome", 9.0) @@ -390,10 +391,11 @@ ml_per_person_per_day = 1000.0 ### Sensor-level noise -For wastewater, a "sensor" is a WWTP/lab pair—the combination of treatment plant and laboratory that determines measurement characteristics: +For wastewater, a "sensor" is a WWTP/lab pair---the combination of treatment plant and laboratory that determines measurement characteristics: ```{python} -# | label: ww-noise-model +#| label: ww-noise-model + # Sensor-level mode: systematic differences between WWTP/lab pairs ww_sensor_mode_rv = VectorizedVariable( "vec_ww_sensor_mode", @@ -417,7 +419,8 @@ ww_noise = HierarchicalNormalNoise( ### Creating the wastewater observation process ```{python} -# | label: create-process +#| label: create-process + ww_process = Wastewater( name="wastewater", shedding_kinetics_rv=shedding_rv, @@ -433,10 +436,13 @@ print(f"Required lookback: {ww_process.lookback_days()} days") ### Timeline alignment -The observation process maintains alignment: day $t$ in output corresponds to day $t$ in input. A temporal PMF of length $L$ covers lags 0 to $L-1$, requiring $L-1$ days of prior history. The method `lookback_days()` returns $L-1$; the first valid observation day is at index `lookback_days()`. +The observation process maintains alignment: day $t$ in output corresponds to day $t$ in input. +A temporal PMF of length $L$ covers lags 0 to $L-1$, requiring $L-1$ days of prior history. +The method `lookback_days()` returns $L-1$; the first valid observation day is at index `lookback_days()`. ```{python} -# | label: helper-function +#| label: helper-function + def first_valid_observation_day(obs_process) -> int: """Return the first day index with complete infection history.""" return obs_process.lookback_days() @@ -447,7 +453,8 @@ def first_valid_observation_day(obs_process) -> int: To see how infections spread into concentrations via shedding kinetics, we simulate from a single-day spike: ```{python} -# | label: simulate-spike +#| label: simulate-spike + n_days = 50 day_one = first_valid_observation_day(ww_process) @@ -478,7 +485,8 @@ with numpyro.handlers.seed(rng_seed=42): We plot the resulting observations starting from the first valid observation day. ```{python} -# | label: plot-spike-infections +#| label: plot-spike-infections + infections_df = pd.DataFrame( { "day": np.arange(n_plot_days), @@ -518,7 +526,8 @@ plot_infections Sampling multiple times from the same infections shows the range of possible observations: ```{python} -# | label: sample-realizations +#| label: sample-realizations + n_samples = 50 ww_samples_list = [] @@ -545,11 +554,10 @@ ww_samples_df = pd.DataFrame(ww_samples_list) ``` ```{python} -# | label: plot-sampled-concentrations +#| label: plot-sampled-concentrations + # Compute mean across samples for each day -mean_by_day = ( - ww_samples_df.groupby("day")["log_concentration"].mean().reset_index() -) +mean_by_day = ww_samples_df.groupby("day")["log_concentration"].mean().reset_index() mean_by_day["sample"] = -1 # Relative peak day for plotting (using mode, not mean, since distribution is skewed) @@ -626,9 +634,11 @@ max_conc = ww_samples_df["log_concentration"].max() ### Sensor-level variability -The previous plot showed variability from repeatedly sampling the entire observation process (resampling sensor parameters and noise each time). In practice, we have multiple physical sensors, each with fixed but unknown characteristics. +The previous plot showed variability from repeatedly sampling the entire observation process (resampling sensor parameters and noise each time). +In practice, we have multiple physical sensors, each with fixed but unknown characteristics. -This plot shows four sensors observing the **same infection spike**. Each sensor has: +This plot shows four sensors observing the **same infection spike**. +Each sensor has: - A **sensor mode** (systematic bias): shifts all observations up or down - A **sensor SD** (measurement precision): determines noise level around predictions @@ -636,14 +646,13 @@ This plot shows four sensors observing the **same infection spike**. Each sensor These parameters are sampled once per sensor, then held fixed across all observations from that sensor. ```{python} -# | label: multi-sensor +#| label: multi-sensor + num_sensors = 4 # Use the same observation times and infections as the sampled-concentrations plot sensor_obs_times = jnp.tile(observation_days, num_sensors) -sensor_ids = jnp.repeat( - jnp.arange(num_sensors, dtype=jnp.int32), len(observation_days) -) +sensor_ids = jnp.repeat(jnp.arange(num_sensors, dtype=jnp.int32), len(observation_days)) subpop_ids = jnp.zeros(num_sensors * len(observation_days), dtype=jnp.int32) with numpyro.handlers.seed(rng_seed=42): @@ -667,11 +676,10 @@ multi_sensor_df = pd.DataFrame( ``` ```{python} -# | label: plot-multi-sensor +#| label: plot-multi-sensor + ( - p9.ggplot( - multi_sensor_df, p9.aes(x="day", y="log_concentration", color="sensor") - ) + p9.ggplot(multi_sensor_df, p9.aes(x="day", y="log_concentration", color="sensor")) + p9.geom_line(size=1) + p9.geom_point(size=2) + p9.labs( @@ -684,29 +692,30 @@ multi_sensor_df = pd.DataFrame( ) ``` -Compare this to the previous plot: here, each colored line represents a distinct physical sensor with its own systematic bias. The vertical spread between sensors reflects differences in sensor modes, while the noise within each line reflects each sensor's measurement precision. During inference, these sensor-specific effects are learned from data. +Compare this to the previous plot: here, each colored line represents a distinct physical sensor with its own systematic bias. +The vertical spread between sensors reflects differences in sensor modes, while the noise within each line reflects each sensor's measurement precision. +During inference, these sensor-specific effects are learned from data. ### Multiple subpopulations -In regional surveillance, each wastewater treatment plant serves a distinct **catchment area** (subpopulation) with its own infection dynamics. The `subpop_indices` array maps each observation to the appropriate infection trajectory. +In regional surveillance, each wastewater treatment plant serves a distinct **catchment area** (subpopulation) with its own infection dynamics. +The `subpop_indices` array maps each observation to the appropriate infection trajectory. This example shows two subpopulations with different epidemic curves: - **Subpopulation 0**: Slow decay (e.g., large urban area with sustained transmission) - **Subpopulation 1**: Fast decay (e.g., smaller community with rapid burnout) -Each subpopulation is observed by its own sensor. The observed concentrations reflect both the underlying infection differences AND the sensor-specific measurement characteristics. +Each subpopulation is observed by its own sensor. +The observed concentrations reflect both the underlying infection differences AND the sensor-specific measurement characteristics. ```{python} -# | label: multi-subpop +#| label: multi-subpop + # Two subpopulations with different infection patterns n_days_mp = 40 -infections_subpop1 = 1000.0 * jnp.exp( - -jnp.arange(n_days_mp) / 20.0 -) # Slow decay -infections_subpop2 = 2000.0 * jnp.exp( - -jnp.arange(n_days_mp) / 10.0 -) # Fast decay +infections_subpop1 = 1000.0 * jnp.exp(-jnp.arange(n_days_mp) / 20.0) # Slow decay +infections_subpop2 = 2000.0 * jnp.exp(-jnp.arange(n_days_mp) / 10.0) # Fast decay infections_multi = jnp.stack([infections_subpop1, infections_subpop2], axis=1) # Two sensors, each observing a different subpopulation @@ -735,7 +744,8 @@ multi_subpop_df = pd.DataFrame( ``` ```{python} -# | label: plot-multi-subpop +#| label: plot-multi-subpop + ( p9.ggplot( multi_subpop_df, @@ -753,9 +763,11 @@ multi_subpop_df = pd.DataFrame( ) ``` -The diverging trajectories reflect the different underlying infection curves. Subpopulation 1 starts higher but decays faster, while Subpopulation 0 maintains more sustained levels. In a full model, you would jointly infer the infection trajectories for each subpopulation while accounting for sensor-specific biases. +The diverging trajectories reflect the different underlying infection curves. +Subpopulation 1 starts higher but decays faster, while Subpopulation 0 maintains more sustained levels. +In a full model, you would jointly infer the infection trajectories for each subpopulation while accounting for sensor-specific biases. ---- +-------------------------------------------------------------------------------- ## Summary diff --git a/docs/tutorials/random_variables.qmd b/docs/tutorials/random_variables.qmd index 51f9385e..30423c5c 100644 --- a/docs/tutorials/random_variables.qmd +++ b/docs/tutorials/random_variables.qmd @@ -7,7 +7,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -17,21 +17,18 @@ jupyter: ## Design principle: all quantities are RandomVariables -In a Bayesian model, all quantities — data, parameters, hyperparameters, derived computations — -are random variables living in a single joint probability model (Gelman et al., BDA3 §1.3). +In a Bayesian model, all quantities --- data, parameters, hyperparameters, derived computations --- are random variables living in a single joint probability model (Gelman et al., BDA3 §1.3). The only distinction is whether a given quantity is known (observed, conditioned on) or unknown (to be inferred). A fixed constant is just a degenerate random variable; an estimated rate is a draw from a prior. Both participate in the same joint distribution. PyRenew embodies this through its `RandomVariable` abstract base class. -All model components implement the same `sample()` interface, -so you can swap a fixed quantity for an estimated one — or vice versa — -without changing any surrounding model code. - +All model components implement the same `sample()` interface, so you can swap a fixed quantity for an estimated one --- or vice versa --- without changing any surrounding model code. ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax.numpy as jnp import numpy as np import numpyro @@ -68,7 +65,8 @@ A degenerate random variable that returns a fixed value. Its `sample()` method simply returns the stored value, unchanged. ```{python} -# | label: deterministic-example +#| label: deterministic-example + ihr_fixed = DeterministicVariable("ihr", 0.01) with numpyro.handlers.seed(rng_seed=0): @@ -76,11 +74,11 @@ with numpyro.handlers.seed(rng_seed=0): print(f"IHR (fixed): {value}") ``` -`DeterministicPMF` specializes this for probability mass functions, -validating at construction time that the values sum to 1: +`DeterministicPMF` specializes this for probability mass functions, validating at construction time that the values sum to 1: ```{python} -# | label: deterministic-pmf-example +#| label: deterministic-pmf-example + delay_pmf = DeterministicPMF( "delay", jnp.array([0.0, 0.1, 0.3, 0.3, 0.2, 0.1]), @@ -94,13 +92,13 @@ with numpyro.handlers.seed(rng_seed=0): ### DistributionalVariable A random variable that draws from a numpyro distribution via `numpyro.sample()`. -The `DistributionalVariable` factory function dispatches to one of two classes -depending on whether the distribution is known at construction time or built at sample time. +The `DistributionalVariable` factory function dispatches to one of two classes depending on whether the distribution is known at construction time or built at sample time. **StaticDistributionalVariable**: the distribution is fully specified at construction time. ```{python} -# | label: static-distributional-example +#| label: static-distributional-example + # IHR with a Beta(2, 198) prior: mean ~1%, moderate uncertainty ihr_estimated = DistributionalVariable("ihr", dist.Beta(2, 198)) print(f"Type: {type(ihr_estimated).__name__}") @@ -109,12 +107,12 @@ with numpyro.handlers.seed(rng_seed=0): print(f"IHR (sampled): {ihr_estimated():.4f}") ``` -**DynamicDistributionalVariable**: the distribution is constructed at sample time -from a callable. This is useful when distribution parameters depend on -other sampled quantities. +**DynamicDistributionalVariable**: the distribution is constructed at sample time from a callable. +This is useful when distribution parameters depend on other sampled quantities. ```{python} -# | label: dynamic-distributional-example +#| label: dynamic-distributional-example + # A Normal whose location is determined at sample time dynamic_rv = DistributionalVariable( "dynamic_normal", @@ -133,7 +131,8 @@ Wraps another `RandomVariable` and applies a deterministic transformation to its This is useful for reparameterizations and derived quantities. ```{python} -# | label: transformed-example +#| label: transformed-example + # Day-of-week effect: Dirichlet draw scaled by 7 # so effects are multiplicative and preserve weekly totals dow_effect = TransformedVariable( @@ -153,17 +152,15 @@ with numpyro.handlers.seed(rng_seed=0): ### Interchangeability -Because all implementations share the `sample()` interface, -you can swap them freely. For example, the `PopulationCounts` observation process -accepts any `RandomVariable` as its `ascertainment_rate_rv`. +Because all implementations share the `sample()` interface, you can swap them freely. +For example, the `PopulationCounts` observation process accepts any `RandomVariable` as its `ascertainment_rate_rv`. The model code is identical whether the rate is fixed or estimated: ```{python} -# | label: interchangeability +#| label: interchangeability + hosp_delay_pmf = jnp.array( - datasets.load_example_infection_admission_interval()[ - "probability_mass" - ].to_numpy() + datasets.load_example_infection_admission_interval()["probability_mass"].to_numpy() ) # Fixed ascertainment rate @@ -185,51 +182,42 @@ hosp_estimated = PopulationCounts( ## The RandomVariable public API -The `RandomVariable` metaclass (defined in `pyrenew.metaclass`) requires -subclasses to implement: +The `RandomVariable` metaclass (defined in `pyrenew.metaclass`) requires subclasses to implement: -| Method | Signature | Purpose | -|--------|-----------|---------| -| `sample` | `sample(**kwargs) -> tuple` | Core computation: return a value, draw from a distribution, or perform a calculation | + | Method | Signature | Purpose | + | -------- | --------------------------- | ------------------------------------------------------------------------------------ | + | `sample` | `sample(**kwargs) -> tuple` | Core computation: return a value, draw from a distribution, or perform a calculation | The metaclass also provides: -| Method | Behavior | -|--------|----------| -| `__call__(**kwargs)` | Alias for `sample(**kwargs)`, so `my_rv()` is equivalent to `my_rv.sample()` | + | Method | Behavior | + | -------------------- | ---------------------------------------------------------------------------- | + | `__call__(**kwargs)` | Alias for `sample(**kwargs)`, so `my_rv()` is equivalent to `my_rv.sample()` | -The `**kwargs` pattern is central to composability: a `RandomVariable` -accepts whatever arguments its `sample()` method needs, and passes -through any additional keyword arguments to internal calls. +The `**kwargs` pattern is central to composability: a `RandomVariable` accepts whatever arguments its `sample()` method needs, and passes through any additional keyword arguments to internal calls. ## Writing a custom RandomVariable -The built-in classes handle most cases, but sometimes you need a component with -custom logic---for instance, one that makes multiple `numpyro.sample()` calls, -performs domain-specific validation, or records derived quantities -via `numpyro.deterministic()`. +The built-in classes handle most cases, but sometimes you need a component with custom logic---for instance, one that makes multiple `numpyro.sample()` calls, performs domain-specific validation, or records derived quantities via `numpyro.deterministic()`. ### Example: ascertainment with day-of-week effects -Hospital admissions data typically shows day-of-week reporting patterns: -fewer admissions are reported on weekends, more on weekdays. +Hospital admissions data typically shows day-of-week reporting patterns: fewer admissions are reported on weekends, more on weekdays. We can model this as a multiplicative adjustment to a baseline ascertainment rate. The predicted hospital admissions on day $t$ are: -$$\lambda_t = \alpha \cdot w_{t \bmod 7} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d$$ +$$ +\lambda_t = \alpha \cdot w_{t \bmod 7} \cdot \sum_{d=0}^{D} I_{t-d} \cdot p_d +$$ -where $\alpha$ is the baseline ascertainment rate, -$w_j$ for $j = 0, \ldots, 6$ are day-of-week multipliers -(positive, summing to 7 so that the weekly total is preserved), -and the summation is the delay convolution. +where $\alpha$ is the baseline ascertainment rate, $w_j$ for $j = 0, \ldots, 6$ are day-of-week multipliers (positive, summing to 7 so that the weekly total is preserved), and the summation is the delay convolution. -We define a custom `RandomVariable` that bundles the -ascertainment rate and day-of-week effect into a single component, -returning a daily rate vector. +We define a custom `RandomVariable` that bundles the ascertainment rate and day-of-week effect into a single component, returning a daily rate vector. ```{python} -# | label: custom-rv-definition +#| label: custom-rv-definition + from jax.typing import ArrayLike @@ -274,15 +262,12 @@ class AscertainmentWithDayOfWeek(RandomVariable): dow_concentration = jnp.asarray(dow_concentration) if dow_concentration.shape != (7,): raise ValueError( - f"dow_concentration must have shape (7,), " - f"got {dow_concentration.shape}" + f"dow_concentration must have shape (7,), got {dow_concentration.shape}" ) if jnp.any(dow_concentration <= 0): raise ValueError("dow_concentration values must be positive") if not (0 <= first_day_offset <= 6): - raise ValueError( - f"first_day_offset must be 0-6, got {first_day_offset}" - ) + raise ValueError(f"first_day_offset must be 0-6, got {first_day_offset}") def sample(self, n_days: int, **kwargs) -> tuple: """ @@ -313,9 +298,7 @@ class AscertainmentWithDayOfWeek(RandomVariable): # Tile the 7-element vector across the timeseries full_cycle = jnp.tile(dow_effect, (n_days // 7) + 1) - daily_dow = full_cycle[ - self.first_day_offset : self.first_day_offset + n_days - ] + daily_dow = full_cycle[self.first_day_offset : self.first_day_offset + n_days] # Combine: daily rate = baseline * day-of-week multiplier daily_rate = baseline * daily_dow @@ -332,7 +315,8 @@ This class bundles three things that belong together: ### Sampling from the custom RV ```{python} -# | label: sample-custom-rv +#| label: sample-custom-rv + ascertainment_rv = AscertainmentWithDayOfWeek( name="hosp", baseline_rate_rv=DistributionalVariable("ihr", dist.Beta(2, 198)), @@ -347,7 +331,8 @@ with numpyro.handlers.seed(rng_seed=42): ``` ```{python} -# | label: plot-custom-rv +#| label: plot-custom-rv + day_labels = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] rates_df = pd.DataFrame( @@ -355,18 +340,14 @@ rates_df = pd.DataFrame( "day": np.arange(n_days), "rate": np.array(daily_rates), "dow": [day_labels[i % 7] for i in range(n_days)], - "day_type": [ - "Weekend" if i % 7 >= 5 else "Weekday" for i in range(n_days) - ], + "day_type": ["Weekend" if i % 7 >= 5 else "Weekday" for i in range(n_days)], } ) ( p9.ggplot(rates_df, p9.aes(x="day", y="rate", fill="day_type")) + p9.geom_col(alpha=0.7, color="black", size=0.3) - + p9.scale_fill_manual( - values={"Weekday": "lightblue", "Weekend": "orange"} - ) + + p9.scale_fill_manual(values={"Weekday": "lightblue", "Weekend": "orange"}) + p9.labs( x="Day", y="Daily Ascertainment Rate", @@ -377,39 +358,35 @@ rates_df = pd.DataFrame( ) ``` -The weekly pattern is clear: weekday rates are higher than weekend rates, -reflecting reduced reporting on weekends. +The weekly pattern is clear: weekday rates are higher than weekend rates, reflecting reduced reporting on weekends. ### Applying to a simulated observation process -We now use the `AscertainmentWithDayOfWeek` instance defined above -to generate day-of-week multipliers within a realistic observation -pipeline, showing how a custom `RandomVariable` integrates into a -data-generating process. +We now use the `AscertainmentWithDayOfWeek` instance defined above to generate day-of-week multipliers within a realistic observation pipeline, showing how a custom `RandomVariable` integrates into a data-generating process. In a hospital surveillance system, the observation pipeline is: 1. **Infections** occur over time (the latent process). 2. **Delay**: each infection leads to a hospital admission after some delay, modeled as a convolution with a delay PMF. -3. **Reporting**: the admission is *reported* with day-of-week effects — fewer reports on weekends, more on weekdays. +3. **Reporting**: the admission is *reported* with day-of-week effects --- fewer reports on weekends, more on weekdays. 4. **Noise**: observed counts include measurement noise. -The day-of-week effect is a *reporting* artifact, so it is -applied **after** the delay convolution. +The day-of-week effect is a *reporting* artifact, so it is applied **after** the delay convolution. We walk through each step explicitly to make the pipeline clear. ```{python} -# | label: simulate-infections +#| label: simulate-infections + # Epidemic curve: exponential growth then decay n_days = 60 infections = 10000.0 * jnp.exp(-((jnp.arange(n_days) - 25.0) ** 2) / 200.0) ``` -We convolve infections with the delay PMF and apply a baseline -ascertainment rate to get the expected admissions before noise: +We convolve infections with the delay PMF and apply a baseline ascertainment rate to get the expected admissions before noise: ```{python} -# | label: delay-convolve +#| label: delay-convolve + baseline_rate = 0.01 expected_admissions = jnp.convolve( baseline_rate * infections, @@ -420,12 +397,11 @@ expected_admissions = jnp.convolve( Now we simulate observed admissions with and without day-of-week effects. Without effects, we add noise directly to the smooth expected curve. -With effects, we first call our custom `ascertainment_rv` to get a -day-of-week multiplier, apply it to the expected admissions -(post-convolution, at the reporting stage), then add noise. +With effects, we first call our custom `ascertainment_rv` to get a day-of-week multiplier, apply it to the expected admissions (post-convolution, at the reporting stage), then add noise. ```{python} -# | label: simulate-with-dow +#| label: simulate-with-dow + n_samples = 20 concentration = 50.0 @@ -446,9 +422,7 @@ for seed in range(n_samples): # Apply day-of-week *after* the delay convolution obs_with_dow = numpyro.sample( "obs_with_dow", - dist.NegativeBinomial2( - expected_admissions * dow_multiplier, concentration - ), + dist.NegativeBinomial2(expected_admissions * dow_multiplier, concentration), ) for i in range(n_days): @@ -471,7 +445,8 @@ for seed in range(n_samples): ``` ```{python} -# | label: plot-comparison +#| label: plot-comparison + results_df = pd.DataFrame(results) ( @@ -491,22 +466,20 @@ results_df = pd.DataFrame(results) ``` The left panel shows smooth variation from noise alone. -The right panel shows the additional weekly oscillation introduced by -day-of-week reporting effects — the sawtooth pattern of weekend dips -and weekday peaks that is characteristic of real hospital admissions data. +The right panel shows the additional weekly oscillation introduced by day-of-week reporting effects --- the sawtooth pattern of weekend dips and weekday peaks that is characteristic of real hospital admissions data. ## Summary ### Choosing a RandomVariable implementation -| Need | Use | -|------|-----| -| Fixed known value | `DeterministicVariable` | -| Fixed known PMF | `DeterministicPMF` | -| Sample from a fixed distribution | `DistributionalVariable` (static) | -| Sample from a distribution parameterized at sample time | `DistributionalVariable` (dynamic, pass a callable) | -| Deterministic transformation of another RV | `TransformedVariable` | -| Multiple sample statements, custom validation, or derived computation | Custom `RandomVariable` subclass | + | Need | Use | + | --------------------------------------------------------------------- | --------------------------------------------------- | + | Fixed known value | `DeterministicVariable` | + | Fixed known PMF | `DeterministicPMF` | + | Sample from a fixed distribution | `DistributionalVariable` (static) | + | Sample from a distribution parameterized at sample time | `DistributionalVariable` (dynamic, pass a callable) | + | Deterministic transformation of another RV | `TransformedVariable` | + | Multiple sample statements, custom validation, or derived computation | Custom `RandomVariable` subclass | ### Writing a custom RandomVariable diff --git a/docs/tutorials/right_truncation.qmd b/docs/tutorials/right_truncation.qmd index 24d5a864..c53b2b5a 100644 --- a/docs/tutorials/right_truncation.qmd +++ b/docs/tutorials/right_truncation.qmd @@ -9,7 +9,7 @@ jupyter: text_representation: extension: .qmd format_name: quarto - format_version: '1.0' + format_version: "1.0" jupytext_version: 1.18.1 kernelspec: display_name: Python 3 (ipykernel) @@ -18,8 +18,9 @@ jupyter: --- ```{python} -# | label: setup -# | output: false +#| label: setup +#| output: false + import jax.numpy as jnp import numpy as np import numpyro @@ -45,7 +46,9 @@ Ignoring this produces a spurious decline in recent counts. PyRenew's observation equation defines the expected observation count as: -$$\mu(t) = \alpha \sum_{s} I(t-s) \, \pi(s)$$ +$$ +\mu(t) = \alpha \sum_{s} I(t-s) \, \pi(s) +$$ where $\alpha$ is the ascertainment rate and $\pi(s)$ is the infection-to-observation delay distribution. @@ -53,13 +56,17 @@ Right-truncation introduces a second delay: the **reporting delay**, which is th PyRenew models this as a multiplicative adjustment applied after the delay convolution. The predicted observation rate becomes: -$$\lambda(t) = F(k_t) \cdot \mu(t)$$ +$$ +\lambda(t) = F(k_t) \cdot \mu(t) +$$ where $F$ is the CDF of the reporting delay and $k_t$ is the number of days between timepoint $t$ and the **data pull date** (the date on which the dataset was extracted from the surveillance system). Concretely, let $T$ denote the last observation day and let $\text{offset} = \text{data pull date} - T$ be the number of additional days between day $T$ and the data pull (the `right_truncation_offset` parameter). Then: -$$k_t = (T - t) + \text{offset} = (T - t) + (\text{data pull date} - T) = \text{data pull date} - t$$ +$$ +k_t = (T - t) + \text{offset} = (T - t) + (\text{data pull date} - T) = \text{data pull date} - t +$$ Because $k_t$ depends only on the data pull date and the timepoint $t$, its behavior is straightforward: timepoints far in the past (small $t$) have large $k_t$, so $F(k_t) \approx 1$ and counts are fully reported. Recent timepoints (large $t$, close to the data pull date) have small $k_t$ and $F(k_t) < 1$, reducing the predicted counts to reflect incomplete reporting. @@ -70,7 +77,8 @@ The reporting delay PMF specifies how quickly events are reported. Given this PMF and the number of days between each timepoint and the data pull date, `compute_prop_already_reported` returns the proportion of events already reported at each timepoint. ```{python} -# | label: reporting-delay-pmf +#| label: reporting-delay-pmf + reporting_delay_pmf = jnp.array([0.4, 0.3, 0.15, 0.08, 0.04, 0.02, 0.01]) days_pmf = np.arange(len(reporting_delay_pmf)) @@ -82,10 +90,9 @@ print(f"CDF: {np.round(cdf, 2)}") ``` ```{python} -# | label: plot-reporting-delay -delay_df = pd.DataFrame( - {"day": days_pmf, "probability": np.array(reporting_delay_pmf)} -) +#| label: plot-reporting-delay + +delay_df = pd.DataFrame({"day": days_pmf, "probability": np.array(reporting_delay_pmf)}) ( p9.ggplot(delay_df, p9.aes(x="day", y="probability")) @@ -100,19 +107,18 @@ delay_df = pd.DataFrame( ``` The `right_truncation_offset` parameter specifies how many additional days of reporting have elapsed beyond the last observation. -An offset of 0 means the data was pulled on the same day as the last observation — only delay-0 reports have arrived for that day. +An offset of 0 means the data was pulled on the same day as the last observation --- only delay-0 reports have arrived for that day. An offset of 3 means three additional days have passed, allowing more reports to trickle in. ```{python} -# | label: compute-proportions +#| label: compute-proportions + n_example = 20 offsets = [0, 2, 4] prop_results = [] for offset in offsets: - prop = compute_prop_already_reported( - reporting_delay_pmf, n_example, offset - ) + prop = compute_prop_already_reported(reporting_delay_pmf, n_example, offset) for i in range(n_example): prop_results.append( { @@ -124,7 +130,8 @@ for offset in offsets: ``` ```{python} -# | label: plot-proportions +#| label: plot-proportions + prop_df = pd.DataFrame(prop_results) prop_df["offset"] = pd.Categorical( prop_df["offset"], @@ -156,11 +163,10 @@ Larger offsets mean more time has elapsed for reports to arrive, shifting the tr We construct two `PopulationCounts` observation processes: one without right-truncation (the default) and one with a `right_truncation_rv` that supplies the reporting delay PMF. ```{python} -# | label: create-processes +#| label: create-processes + hosp_delay_pmf = jnp.array( - datasets.load_example_infection_admission_interval()[ - "probability_mass" - ].to_numpy() + datasets.load_example_infection_admission_interval()["probability_mass"].to_numpy() ) delay_rv = DeterministicPMF("inf_to_hosp_delay", hosp_delay_pmf) ihr_rv = DeterministicVariable("ihr", 0.01) @@ -178,16 +184,15 @@ process_with_trunc = PopulationCounts( ascertainment_rate_rv=ihr_rv, delay_distribution_rv=delay_rv, noise=NegativeBinomialNoise(concentration_rv), - right_truncation_rv=DeterministicPMF( - "reporting_delay", reporting_delay_pmf - ), + right_truncation_rv=DeterministicPMF("reporting_delay", reporting_delay_pmf), ) ``` We simulate an epidemic that is still growing at the end of the observation window. ```{python} -# | label: simulate-and-sample +#| label: simulate-and-sample + day_one = process_no_trunc.lookback_days() n_total = 80 n_plot_days = n_total - day_one @@ -204,7 +209,8 @@ with numpyro.handlers.seed(rng_seed=0): The plot below overlays predicted admissions with and without right-truncation so the early agreement and late divergence are easy to see. ```{python} -# | label: plot-predicted-comparison +#| label: plot-predicted-comparison + pred_rows = [] for i in range(n_plot_days): pred_rows.append( @@ -229,9 +235,7 @@ pred_df["type"] = pd.Categorical( ) ( - p9.ggplot( - pred_df, p9.aes(x="day", y="admissions", color="type", linetype="type") - ) + p9.ggplot(pred_df, p9.aes(x="day", y="admissions", color="type", linetype="type")) + p9.geom_line(size=1) + p9.scale_color_manual(values=["steelblue", "#e41a1c"]) + p9.scale_linetype_manual(values=["solid", "dashed"]) @@ -247,9 +251,8 @@ pred_df["type"] = pd.Categorical( ``` The two curves agree perfectly in the early period when all reports have arrived. -Near the right edge the truncated curve turns downward — recent counts are depressed because reports have not yet arrived. -Without the right-truncation adjustment, a model fit to the dashed curve would infer that the epidemic is slowing down — a dangerous misinterpretation during an active outbreak. - +Near the right edge the truncated curve turns downward --- recent counts are depressed because reports have not yet arrived. +Without the right-truncation adjustment, a model fit to the dashed curve would infer that the epidemic is slowing down --- a dangerous misinterpretation during an active outbreak. ## Sampled observations @@ -257,7 +260,8 @@ Right-truncation also affects the sampled (noisy) observations, not just the pre The noise model draws from a distribution centered on the adjusted predictions. ```{python} -# | label: sample-noisy-comparison +#| label: sample-noisy-comparison + n_samples = 30 noisy_results = [] for seed in range(n_samples): @@ -287,15 +291,14 @@ for seed in range(n_samples): ``` ```{python} -# | label: plot-noisy-comparison +#| label: plot-noisy-comparison + noisy_df = pd.DataFrame(noisy_results) mean_df = noisy_df.groupby(["day", "type"])["admissions"].mean().reset_index() ( p9.ggplot(noisy_df, p9.aes(x="day", y="admissions")) - + p9.geom_line( - p9.aes(group="sample"), alpha=0.2, size=0.4, color="steelblue" - ) + + p9.geom_line(p9.aes(group="sample"), alpha=0.2, size=0.4, color="steelblue") + p9.geom_line( data=mean_df, mapping=p9.aes(x="day", y="admissions"), @@ -314,20 +317,20 @@ mean_df = noisy_df.groupby(["day", "type"])["admissions"].mean().reset_index() The red line is the mean across samples. In the top panel (complete reporting), the mean tracks the growing epidemic. -In the bottom panel (right-truncated), the mean turns downward at the right edge — a spurious decline caused entirely by incomplete reporting. +In the bottom panel (right-truncated), the mean turns downward at the right edge --- a spurious decline caused entirely by incomplete reporting. ## Summary Right-truncation adjustment is enabled by passing a `right_truncation_rv` (reporting delay PMF) at construction time and a `right_truncation_offset` at sample time. -| Parameter | Where | Purpose | -|-----------|-------|---------| -| `right_truncation_rv` | Constructor | Reporting delay PMF | -| `right_truncation_offset` | `sample()` | Days between last observation and data pull | + | Parameter | Where | Purpose | + | ------------------------- | ----------- | ------------------------------------------- | + | `right_truncation_rv` | Constructor | Reporting delay PMF | + | `right_truncation_offset` | `sample()` | Days between last observation and data pull | When either is `None`, the adjustment is disabled and the process behaves identically to one without right-truncation. This makes it straightforward to compare models with and without the adjustment, or to supply the reporting delay as either a fixed PMF or an inferred distribution. In practice, right-truncation adjustment should be enabled when **fitting** to observed data, so the model correctly attributes low recent counts to incomplete reporting rather than a true decline. -However, it should be **disabled when forecasting**: future timepoints have not yet occurred, so there is no reporting delay to account for — applying the adjustment would nonsensically shrink forecasted counts toward zero. +However, it should be **disabled when forecasting**: future timepoints have not yet occurred, so there is no reporting delay to account for --- applying the adjustment would nonsensically shrink forecasted counts toward zero. A typical workflow is to fit with `right_truncation_offset` set to the actual offset, then generate forecasts with `right_truncation_offset=None`. diff --git a/docs_scripts/quarto_python_formatter.py b/docs_scripts/quarto_python_formatter.py deleted file mode 100755 index 42b25a0e..00000000 --- a/docs_scripts/quarto_python_formatter.py +++ /dev/null @@ -1,74 +0,0 @@ -#!/usr/bin/env python3 -# numpydoc ignore=GL08 -import re -import shutil -import subprocess -import sys -import tempfile -from pathlib import Path -from re import Match - - -def format_python_code(code: str, ruff_args: list[str]) -> str: # numpydoc ignore=RT01 - """Format Python code using Ruff with custom arguments.""" - try: - cmd = ["ruff", "format", "-"] + ruff_args - result = subprocess.run( - cmd, - input=code, - capture_output=True, - text=True, - check=True, - ) - return result.stdout - except subprocess.CalledProcessError: - print("Error: Failed to format Python code with Ruff.", file=sys.stderr) - return code - - -def replace_code_block( - match: Match[str], ruff_args: list[str] -) -> str: # numpydoc ignore=RT01 - """Replace code block with formatted version.""" - return f"{match.group(1)}\n{format_python_code(match.group(2), ruff_args)}{match.group(3)}" - - -def process_file(filepath: Path, ruff_args: list[str]) -> None: # numpydoc ignore=RT01 - """Process the given file, formatting Python code blocks.""" - python_code_block_pattern = r"(```\{python\})(.*?)(```)" - try: - content = filepath.read_text() - formatted_content = re.sub( - python_code_block_pattern, - lambda m: replace_code_block(m, ruff_args), - content, - flags=re.DOTALL, - ) - - with tempfile.NamedTemporaryFile(mode="w", delete=False) as temp_file: - temp_file.write(formatted_content) - temp_filepath = Path(temp_file.name) - - shutil.move(str(temp_filepath), str(filepath)) - except OSError as e: - print(f"Error processing file {filepath}: {e}", file=sys.stderr) - sys.exit(1) - - -if __name__ == "__main__": - if len(sys.argv) < 3: - print( - 'Usage: python docs_scripts/quarto_python_formatter.py "RUFF_ARGS" [filename2.qmd ...]' - ) - sys.exit(1) - - ruff_args = sys.argv[1].split() - - missing_files = [file for file in sys.argv[2:] if not Path(file).exists()] - if missing_files: - raise FileNotFoundError( - f"The following file(s) do not exist: {', '.join(missing_files)}." - ) - for filepath in sys.argv[2:]: - path = Path(filepath) - process_file(path, ruff_args) diff --git a/panache.toml b/panache.toml new file mode 100644 index 00000000..798a6847 --- /dev/null +++ b/panache.toml @@ -0,0 +1,7 @@ +flavor = "gfm" +[format] +wrap = "sentence" +[formatters] +python = "ruff" +[linters] +python = "ruff" diff --git a/pyrenew/datasets/hospital_admissions_data/README.md b/pyrenew/datasets/hospital_admissions_data/README.md index a5e5c9a7..0f1ec319 100644 --- a/pyrenew/datasets/hospital_admissions_data/README.md +++ b/pyrenew/datasets/hospital_admissions_data/README.md @@ -14,18 +14,18 @@ Vintaged snapshot of COVID-19 hospital admissions data as it would have been available on 2023-11-06. - **Coverage**: California (CA) only -- **Date range**: 2023-01-01 to 2023-11-06 (~310 days) +- **Date range**: 2023-01-01 to 2023-11-06 (\~310 days) - **Size**: 12 KB, 311 rows - **Use case**: Tutorials, single-jurisdiction model development ### Schema -| Column | Type | Description | -|--------|------|-------------| -| `date` | string | Date in ISO 8601 format (YYYY-MM-DD) | -| `location` | string | State 2-letter abbreviation | -| `daily_hosp_admits` | integer | Daily COVID-19 hospital admissions count | -| `pop` | integer | State population | + | Column | Type | Description | + | ------------------- | ------- | ---------------------------------------- | + | `date` | string | Date in ISO 8601 format (YYYY-MM-DD) | + | `location` | string | State 2-letter abbreviation | + | `daily_hosp_admits` | integer | Daily COVID-19 hospital admissions count | + | `pop` | integer | State population | ## Usage diff --git a/pyrenew/datasets/wastewater_nwss_data/README.md b/pyrenew/datasets/wastewater_nwss_data/README.md index 1a8c80f6..ecf1c324 100644 --- a/pyrenew/datasets/wastewater_nwss_data/README.md +++ b/pyrenew/datasets/wastewater_nwss_data/README.md @@ -11,33 +11,34 @@ ### `fake_nwss.csv` -Synthetic wastewater surveillance data in NWSS (National Wastewater Surveillance System) format. Contains deliberately added noise for public release. +Synthetic wastewater surveillance data in NWSS (National Wastewater Surveillance System) format. +Contains deliberately added noise for public release. - **Jurisdictions**: CA, WA, NM (real states) plus XX, YY, ZZ (fictional) - **WWTPs**: CA (5), WA (4), NM (2), others (4 each) -- **Date range**: 2023-01-01 to 2023-11-06 (~310 days) +- **Date range**: 2023-01-01 to 2023-11-06 (\~310 days) - **Size**: 487 KB, 3,286 rows - **Granularity**: Site-lab-date level (multiple labs per WWTP) - **Use case**: Tutorials, multi-signal model development ### Schema -| Column | Type | Description | -|--------|------|-------------| -| `wwtp_jurisdiction` | string | State/territory abbreviation | -| `wwtp_name` | string | Wastewater treatment plant identifier | -| `county_names` | string | County code | -| `lab_id` | integer | Laboratory identifier | -| `population_served` | integer | Population served by this WWTP | -| `sample_location` | string | Sample collection point (e.g., "wwtp") | -| `sample_matrix` | string | Sample type (e.g., "raw wastewater") | -| `pcr_target_units` | string | Measurement units | -| `pcr_target` | string | Target pathogen (always "sars-cov-2") | -| `pcr_target_avg_conc` | float | Viral RNA concentration | -| `lod_sewage` | float | Limit of detection for this sample | -| `pcr_target_below_lod` | integer | Below detection limit flag (0=above, 1=below) | -| `sample_collect_date` | string | Sample collection date (YYYY-MM-DD) | -| `quality_flag` | string | Data quality flags | + | Column | Type | Description | + | ---------------------- | ------- | --------------------------------------------- | + | `wwtp_jurisdiction` | string | State/territory abbreviation | + | `wwtp_name` | string | Wastewater treatment plant identifier | + | `county_names` | string | County code | + | `lab_id` | integer | Laboratory identifier | + | `population_served` | integer | Population served by this WWTP | + | `sample_location` | string | Sample collection point (e.g., "wwtp") | + | `sample_matrix` | string | Sample type (e.g., "raw wastewater") | + | `pcr_target_units` | string | Measurement units | + | `pcr_target` | string | Target pathogen (always "sars-cov-2") | + | `pcr_target_avg_conc` | float | Viral RNA concentration | + | `lod_sewage` | float | Limit of detection for this sample | + | `pcr_target_below_lod` | integer | Below detection limit flag (0=above, 1=below) | + | `sample_collect_date` | string | Sample collection date (YYYY-MM-DD) | + | `quality_flag` | string | Data quality flags | ## Usage