Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the NUTS sampler constructor #35

Merged
merged 2 commits into from
Jul 8, 2022
Merged

Add the NUTS sampler constructor #35

merged 2 commits into from
Jul 8, 2022

Conversation

rlouf
Copy link
Member

@rlouf rlouf commented Jun 10, 2022

In this PR we add a function that uses the NUTS kernel implemented in aehmc to build a function that samples from the posterior distribution of a model specified in a ModelInfo instance. It is a first step towards the general-purpose sampler mentioned in #26.

@rlouf rlouf force-pushed the nuts branch 2 times, most recently from 724c384 to 499df75 Compare June 10, 2022 15:20
aemcmc/nuts.py Outdated Show resolved Hide resolved
aemcmc/nuts.py Outdated Show resolved Hide resolved
@brandonwillard brandonwillard added the enhancement New feature or request label Jun 22, 2022
@brandonwillard brandonwillard changed the title Add the NUTS sampler Add the NUTS sampler constructor Jun 22, 2022
@rlouf rlouf force-pushed the nuts branch 2 times, most recently from e335301 to 0a2b99b Compare June 23, 2022 12:56
@rlouf
Copy link
Member Author

rlouf commented Jun 23, 2022

My original plan of treating the general case (only some variables sampled using NUTS) was too general so I am working on a simpler version which should be mergeable sooner.

@codecov
Copy link

codecov bot commented Jun 23, 2022

Codecov Report

Merging #35 (9793d2a) into main (3c5f7dc) will decrease coverage by 0.27%.
The diff coverage is 98.07%.

@@             Coverage Diff             @@
##              main      #35      +/-   ##
===========================================
- Coverage   100.00%   99.72%   -0.28%     
===========================================
  Files            4        6       +2     
  Lines          295      358      +63     
  Branches        20       31      +11     
===========================================
+ Hits           295      357      +62     
- Partials         0        1       +1     
Impacted Files Coverage Δ
aemcmc/nuts.py 98.07% <98.07%> (ø)
aemcmc/dists.py 100.00% <0.00%> (ø)
aemcmc/transforms.py 100.00% <0.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3c5f7dc...9793d2a. Read the comment docs.

@rlouf rlouf force-pushed the nuts branch 4 times, most recently from 24806bd to a6c80c5 Compare June 24, 2022 13:22
@rlouf
Copy link
Member Author

rlouf commented Jun 24, 2022

I finished a first (working) draft of the NUTS sampler constructor. The test passes locally, but we need a new release of aehmc that includes the recent changes on the NUTS kernel API. Here are a few remarks / questions in no specific order:

I left the warmup out; how we integrate it should be part of a broader reflection on the output of aemcmc and how we expect users to interact with it. One solution is to have a build_sampler function that returns a kernel with a adapt keyword argument. Set to true the kernel will run the adaptation, otherwise it just moves the chain. Another one is to have build_sampler return a kernel for adaptation and another for sampling. We should probably open a new issue to discuss the design. Once we have a vague idea I can add the NUTS warmup.

Since this is meant to be a black box it would be nice to output some information about the kernel that was built (transforms, for instance) in a way that is useful to users.

We need to set conventions for the output of kernels. We currently have the kernel's state, the values of the parameters in the transformed and untransformed space, and updates. The Gibbs samplers' output is different. Again, I will open an issue to discuss this further.

I currently pull the default transforms from aeppl, but I think that since these "defaults" are sampler-specific the dispatcher we use for NUTS should live in aemcmc (and the transform mechanism stay in aeppl, of course);

@rlouf
Copy link
Member Author

rlouf commented Jun 28, 2022

To be able to add the warmup we will need to address the following issue in Aehmc: aesara-devs/aehmc#65.

Copy link
Member

@brandonwillard brandonwillard left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs some minor docstring changes; otherwise, looks good.

aemcmc/nuts.py Outdated Show resolved Hide resolved
aemcmc/nuts.py Outdated Show resolved Hide resolved
@rlouf
Copy link
Member Author

rlouf commented Jul 1, 2022

I wanted to simplify the tests, and found a subtle bug:

unobserved_rvs = set(model.rvs_to_values.keys()) - set(model.observed_rvs)

doesn't preserve the ordering in rvs_to_values. It is generally not a big deal since users will interact with the unraveled parameters, but it will be a source of confusion for whoever wants to work with the raveled position.

@rlouf rlouf force-pushed the nuts branch 2 times, most recently from db82b72 to 95f685e Compare July 1, 2022 08:28
@rlouf
Copy link
Member Author

rlouf commented Jul 1, 2022

Corrected the bug, simplified the test (untransformed and transformed variables in the same model), good to merge if the tests pass (Codecov complains about a part of transform_backward not being covered by tests, but it is).

aemcmc/nuts.py Outdated Show resolved Hide resolved
@brandonwillard
Copy link
Member

I don't see the apparent coverage patch (or project) issue reported by Codecov, so we can ignore those failures.

@rlouf rlouf force-pushed the nuts branch 3 times, most recently from 9794b3a to 5806c67 Compare July 7, 2022 11:02
@rlouf
Copy link
Member Author

rlouf commented Jul 7, 2022

I added an explicit return type as requested and updated the code / dependencies to include the latest changes in AeHMC.

@rlouf rlouf merged commit 20611eb into aesara-devs:main Jul 8, 2022
@rlouf rlouf deleted the nuts branch July 8, 2022 08:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants