Add a function that constructs samplers #45

Merged

brandonwillard
Member

@brandonwillard brandonwillard commented Jul 9, 2022

This PR provides an initial implementation of #3 and of the related interface function mentioned in #26. It adds a general sampler constructor, aemcmc.basic.construct_sampler, that returns a dict mapping RandomVariables to their sample steps.

The current approach uses a Feature called SamplerTracker to track a dict from RandomVariables to all of their discovered sample steps, even when there is more than one potential sampler for the same RandomVariable. Sample steps are discovered by walking the graph with standard local rewriters that write their results to the dict in SamplerTracker. This lets us keep the original observation variable graphs intact in relation to every other unobserved variable (e.g. so we can see when a variable is in a particular hierarchical relationship with another variable).
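To make the tracking idea concrete, here is a minimal, self-contained sketch in plain Python. It does not use Aesara or the actual SamplerTracker; `Node`, `ToySamplerTracker`, `walk`, and both rewriter functions are hypothetical names invented for illustration. It only shows the pattern described above: local rewriters inspect each node, record candidate sample steps in a shared dict, and leave the graph itself untouched.

```python
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Set, Tuple


@dataclass(frozen=True)
class Node:
    """A toy graph node: an operation name applied to input nodes."""

    op: str
    inputs: Tuple["Node", ...] = ()


class ToySamplerTracker:
    """Collects every sample step discovered for each random variable."""

    def __init__(self) -> None:
        self.rvs_to_samplers: Dict[Node, List[str]] = defaultdict(list)

    def record(self, rv: Node, step: str) -> None:
        self.rvs_to_samplers[rv].append(step)


Rewriter = Callable[[Node, ToySamplerTracker], None]


def fallback_rule(node: Node, tracker: ToySamplerTracker) -> None:
    # Hypothetical rule: any gamma variable can use a generic fallback step.
    if node.op == "gamma":
        tracker.record(node, "random_walk_mh")


def conjugate_rule(node: Node, tracker: ToySamplerTracker) -> None:
    # Hypothetical rule: a Poisson observation with a gamma rate yields a
    # closed-form posterior step for the rate.
    if node.op == "poisson" and node.inputs and node.inputs[0].op == "gamma":
        tracker.record(node.inputs[0], "gamma_poisson_posterior")


def walk(root: Node, tracker: ToySamplerTracker, rewriters: List[Rewriter]) -> None:
    """Visit each node once, inputs first, and let every rewriter inspect it."""
    seen: Set[Node] = set()

    def visit(node: Node) -> None:
        if node in seen:
            return
        seen.add(node)
        for inp in node.inputs:
            visit(inp)
        for rewriter in rewriters:
            rewriter(node, tracker)

    visit(root)


rate = Node("gamma")
obs = Node("poisson", (rate,))
tracker = ToySamplerTracker()
walk(obs, tracker, [fallback_rule, conjugate_rule])
print(tracker.rvs_to_samplers[rate])
```

Because the rewriters only write to the tracker, `obs` still points at the original `rate` node after the walk, which is the property the paragraph above relies on.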

In order to get around some DimShuffle annoyances during unification/pattern-matching, a SubsumingElemwise Op was added and is used to replace Elemwise(DimShuffle(x), ...) graphs with SubsumingElemwise(x, ...) graphs (i.e. ones that subsume the DimShuffles). Since SubsumingElemwise inherits from OpFromGraph, those nodes can be expanded later on to reproduce the original Elemwise + DimShuffle sub-graphs.
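The subsumption step can be illustrated with a toy term representation. This is a sketch under stated assumptions, not the Aesara implementation: `Var`, `Apply`, `subsume`, and `expand` are hypothetical, and instead of wrapping an inner graph the way an OpFromGraph does, the toy node simply remembers which argument positions had a DimShuffle so the original form can be rebuilt.

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Apply:
    op: str
    inputs: Tuple["Term", ...]
    # Positions whose DimShuffle was subsumed (only used by SubsumingElemwise).
    shuffled: Tuple[int, ...] = ()


Term = Union[Var, Apply]


def subsume(term: Term) -> Term:
    """Rewrite Elemwise(DimShuffle(x), ...) into SubsumingElemwise(x, ...)."""
    if isinstance(term, Apply) and term.op == "Elemwise":
        shuffled = tuple(
            i
            for i, arg in enumerate(term.inputs)
            if isinstance(arg, Apply) and arg.op == "DimShuffle"
        )
        if shuffled:
            stripped = tuple(
                arg.inputs[0] if i in shuffled else arg
                for i, arg in enumerate(term.inputs)
            )
            return Apply("SubsumingElemwise", stripped, shuffled)
    return term


def expand(term: Term) -> Term:
    """Rebuild the original Elemwise + DimShuffle form."""
    if isinstance(term, Apply) and term.op == "SubsumingElemwise":
        restored = tuple(
            Apply("DimShuffle", (arg,)) if i in term.shuffled else arg
            for i, arg in enumerate(term.inputs)
        )
        return Apply("Elemwise", restored)
    return term


original = Apply("Elemwise", (Apply("DimShuffle", (Var("x"),)), Var("y")))
subsumed = subsume(original)
print(subsumed.inputs)                  # the DimShuffle no longer appears
print(expand(subsumed) == original)     # the round trip restores the graph
```

With the DimShuffle hidden, a pattern written against `x` and `y` directly can match `subsumed.inputs`, which is the unification convenience described above.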

  • Work out the design of the main "processing loop"
    This loop needs to iterate over all the unobserved RandomVariables and construct samplers for them. While doing so, it needs to use references to the previously constructed samplers' outputs.
  • Do something about RandomVariable rewriting
    The issue here is that RandomVariables that are canonicalized are no longer the same RandomVariables that the user created, so we need a means of keeping a map between the two. Lifting Ops through RandomVariables is one of the main ways this issue shows up. N.B.: This is also one situation in which we could use complete relations (i.e. two-way rewrites).
  • Restrict the kinds of DimShuffles SubsumingElemwise will subsume (i.e. limit to only ones that add the appropriate broadcast dimensions).
    In cases where the original graph was an Elemwise(DimShuffle1(DimShuffle2(x)), ...) and the two DimShuffle* are merged, we will need to un-merge/expand them in order to use SubsumingElemwise.
  • Add initial value checks to tests (i.e. make sure that initial value variables correctly replace their corresponding RandomVariables).
  • Finish refactoring aemcmc.gibbs (e.g. remove combination samplers in favor of construct_sampler, create more local_optimizers, generalize local_optimizer construction, etc.)
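The "processing loop" item in the list above can be sketched numerically. The real loop would build symbolic graphs in which each new sampler refers to the output variables of the samplers built before it; the version below is only an analogy in plain Python. `construct_toy_sampler`, `State`, and `Step` are hypothetical names, and choosing the first candidate step for each variable is an assumption made for brevity.

```python
from typing import Callable, Dict, List

State = Dict[str, float]
Step = Callable[[State], float]


def construct_toy_sampler(candidates: Dict[str, List[Step]]) -> Callable[[State], State]:
    """Pick one step per unobserved variable and chain them into one sweep."""
    chosen: Dict[str, Step] = {}
    for rv, steps in candidates.items():
        if not steps:
            raise ValueError(f"no sample step was found for {rv!r}")
        # Assumption: take the first discovered step; a real selection
        # strategy could rank the candidates instead.
        chosen[rv] = steps[0]

    def sweep(state: State) -> State:
        new_state = dict(state)
        for rv, step in chosen.items():
            # Each step sees the values already updated earlier in the sweep,
            # mirroring "references to the previously constructed samplers'
            # outputs".
            new_state[rv] = step(new_state)
        return new_state

    return sweep


# Deterministic stand-ins for conditional draws, so the flow is easy to follow.
candidates = {
    "a": [lambda s: s["b"] + 1],
    "b": [lambda s: 2 * s["a"]],
}
sweep = construct_toy_sampler(candidates)
result = sweep({"a": 0, "b": 1})
print(result)
```

Here the update for `b` uses the freshly updated `a` rather than its initial value, which is the ordering concern the first list item raises.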

@brandonwillard brandonwillard added the enhancement (New feature or request), important, and refactoring (A change that improves the codebase but doesn't necessarily introduce a new feature) labels Jul 9, 2022
@brandonwillard brandonwillard self-assigned this Jul 9, 2022
@brandonwillard brandonwillard force-pushed the add-construct_sampler-function branch 3 times, most recently from 4b90969 to ea2456b Compare July 13, 2022 04:01
@brandonwillard brandonwillard marked this pull request as ready for review July 13, 2022 04:01
@brandonwillard brandonwillard marked this pull request as draft July 13, 2022 04:09
@brandonwillard brandonwillard force-pushed the add-construct_sampler-function branch 4 times, most recently from 0ea5140 to c8fe1cb Compare July 13, 2022 04:49
@codecov

codecov bot commented Jul 13, 2022

Codecov Report

Merging #45 (65c7a37) into main (20611eb) will decrease coverage by 2.54%.
The diff coverage is 97.59%.

❗ Current head 65c7a37 differs from pull request most recent head 39ac1a5. Consider uploading reports for the commit 39ac1a5 to get more accurate results

@@            Coverage Diff             @@
##             main      #45      +/-   ##
==========================================
- Coverage   99.74%   97.20%   -2.55%     
==========================================
  Files           7        9       +2     
  Lines         391      572     +181     
  Branches       31       62      +31     
==========================================
+ Hits          390      556     +166     
- Misses          0        5       +5     
- Partials        1       11      +10     
Impacted Files Coverage Δ
aemcmc/gibbs.py 91.87% <93.05%> (-8.13%) ⬇️
aemcmc/opt.py 98.67% <98.67%> (ø)
aemcmc/basic.py 100.00% <100.00%> (ø)
aemcmc/conjugates.py 100.00% <100.00%> (ø)
aemcmc/dists.py 100.00% <100.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 20611eb...39ac1a5.

Review thread on aemcmc/gibbs.py (outdated, resolved)
Review thread on aemcmc/conjugates.py (outdated, resolved)
@brandonwillard brandonwillard force-pushed the add-construct_sampler-function branch 2 times, most recently from 7b1a633 to f335ab3 Compare July 13, 2022 22:47
@brandonwillard brandonwillard linked an issue Jul 13, 2022 that may be closed by this pull request
@brandonwillard brandonwillard force-pushed the add-construct_sampler-function branch 3 times, most recently from d57903a to f25d53e Compare July 16, 2022 20:03
@brandonwillard brandonwillard marked this pull request as ready for review July 16, 2022 20:08
@brandonwillard
Member Author

brandonwillard commented Jul 16, 2022

We now have a complete working example in the test test_basic.py:test_create_gibbs.

@brandonwillard brandonwillard force-pushed the add-construct_sampler-function branch 2 times, most recently from 9c12ab6 to f2a41f0 Compare July 16, 2022 22:14
@brandonwillard
Member Author

I've finished refactoring the sampler steps and filled out the docstrings, so this should be ready to merge when/if it passes.

@brandonwillard brandonwillard force-pushed the add-construct_sampler-function branch 4 times, most recently from 65c7a37 to 58882f0 Compare July 22, 2022 06:05
@brandonwillard brandonwillard merged commit 213e0fd into aesara-devs:main Jul 22, 2022
@brandonwillard brandonwillard deleted the add-construct_sampler-function branch July 22, 2022 06:14
@brandonwillard brandonwillard linked an issue Jul 22, 2022 that may be closed by this pull request

Successfully merging this pull request may close these issues.

  • Make horseshoe_nbinom_gibbs return a kernel
  • Create a framework for matching models to samplers
2 participants