Add the forward-backward Wasserstein Gaussian variational inference algorithm #210
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Benchmark Results
| Benchmark suite | Current: 4b34875 | Previous: 85dcaa9 | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 4126219477.5 ns | 4031479469.5 ns | 1.02 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 1002333746 ns | 1118649204 ns | 0.90 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 923499003 ns | 1234572462 ns | 0.75 |
| normal/RepGradELBO + STL/fullrank/Zygote | 4098155528 ns | 3953335724.5 ns | 1.04 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1511496123.5 ns | 1598412120.5 ns | 0.95 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 981564964 ns | 1259515620.5 ns | 0.78 |
| normal/RepGradELBO/meanfield/Zygote | 2567299120.5 ns | 2821886790 ns | 0.91 |
| normal/RepGradELBO/meanfield/ReverseDiff | 663379206 ns | 771972557 ns | 0.86 |
| normal/RepGradELBO/meanfield/Mooncake | 819795409 ns | 1100687154 ns | 0.74 |
| normal/RepGradELBO/fullrank/Zygote | 2605812450.5 ns | 2850045117.5 ns | 0.91 |
| normal/RepGradELBO/fullrank/ReverseDiff | 826704652 ns | 943912138.5 ns | 0.88 |
| normal/RepGradELBO/fullrank/Mooncake | 862928027.5 ns | 1104866502 ns | 0.78 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 6113635200 ns | 5595282208 ns | 1.09 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 2012778478 ns | 2324059761 ns | 0.87 |
| normal + bijector/RepGradELBO + STL/meanfield/Mooncake | 4250639615 ns | 4123667267 ns | 1.03 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 6247668262 ns | 5540335701 ns | 1.13 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 2705739852 ns | 2915496029.5 ns | 0.93 |
| normal + bijector/RepGradELBO + STL/fullrank/Mooncake | 4420975128.5 ns | 4073600243.5 ns | 1.09 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 4501034455.5 ns | 4296709840 ns | 1.05 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 1624340725 ns | 1940028412 ns | 0.84 |
| normal + bijector/RepGradELBO/meanfield/Mooncake | 4115187135.5 ns | 3915092871 ns | 1.05 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 4579222710 ns | 4297139407 ns | 1.07 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 1855247349 ns | 2164520758 ns | 0.86 |
| normal + bijector/RepGradELBO/fullrank/Mooncake | 4223167692.5 ns | 3916935716 ns | 1.08 |
This comment was automatically generated by workflow using github-action-benchmark.
AdvancedVI.jl documentation for PR #210 is available at:
mhauru
left a comment
I only did a superficial read of the code, but didn't spot anything major, just a few style nits and a point about bumping the version in Project.toml.
HISTORY.md (outdated)
@@ -1,3 +1,10 @@
# Release 0.6
Project.toml probably needs a matching update of the version number.
You could release this as 0.5.1 if you want, since, as far as I understand, it doesn't break any existing code that uses AdvancedVI. It's not wrong to use 0.6 though if you think the change is major enough to warrant it.
We do export more symbols though. So isn't it supposed to be a minor release?
As I understand it, you don't have to bump the minor version. This thing says
> In particular, a package may set version = "0.2.4" when it has feature additions compared to 0.2.3 as long as it remains backward compatible with 0.2.0. See also The version field.
You could argue that any exporting of new symbols is breaking because it could clash with existing symbols from outside your package. That does have a point, but it's also a rule that is constantly broken e.g. by Julia itself, as they introduce new API in minor releases like 1.12.
By the semver standard you are in fact free to do whatever, because
> Major version zero (0.y.z) is for initial development. Anything MAY change at any time. The public API SHOULD NOT be considered stable.
I think the question is: do you want users of ADVI who have specified compat in the usual, guarding-against-breakage way, with something like ^0.5, to get this version installed by default, or only if they manually bump the compat?
I'm happy if you want to release this as 0.6, just wanted to highlight that you have a choice.
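To make the choice concrete, here is a sketch of the compat entry a downstream package might carry. The downstream package is hypothetical; only the caret semantics of Pkg are relied on. Under those semantics, a 0.5.1 release would be picked up automatically by such users, while a 0.6.0 release would only be installed after they edit the entry.

```toml
# Fragment of a hypothetical downstream Project.toml
[compat]
AdvancedVI = "0.5"  # equivalent to "^0.5": allows any 0.5.x, excludes 0.6.0
```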
Co-authored-by: Markus Hauru <markus@mhauru.org>
mhauru
left a comment
I'm happy with this. Would be good if someone who understands the algorithm also approves, but I appreciate that such people are few, far between, and busy.
Let me ping @yebai for good measure. Hong, are you happy if I go forward with this?
This adds the forward-backward Wasserstein Gaussian variational inference algorithm by Diao et al. [1]. This algorithm minimizes the KL divergence by running proximal stochastic gradient descent in the Bures-Wasserstein space. (The metric is the Wasserstein-2 distance and the gradient is the corresponding Bures-Wasserstein gradient.) Since this is a measure-space algorithm, it tends to converge faster than BBVI/ADVI as long as the step size is well tuned.
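For readers unfamiliar with the method, one iteration can be sketched as follows. This is my reading of the stochastic update in Diao et al., not a description of the code in this PR, which may differ in details such as how the gradient and Hessian are estimated. The notation is assumed here: the target is $\pi \propto \exp(-V)$, the current iterate is $\mathcal{N}(m_k, \Sigma_k)$, and $\eta > 0$ is the step size. The first four lines are the forward (gradient) step on the potential energy, and the last line is the closed-form backward (proximal) step on the entropy.

```latex
% Assumed notation: target \pi \propto \exp(-V), iterate N(m_k, \Sigma_k), step size \eta > 0.
\begin{align*}
X_k &\sim \mathcal{N}(m_k, \Sigma_k), \\
m_{k+1} &= m_k - \eta \, \nabla V(X_k), \\
M_k &= I - \eta \, \nabla^2 V(X_k), \\
\Sigma_{k+1/2} &= M_k \Sigma_k M_k, \\
\Sigma_{k+1} &= \tfrac{1}{2}\Bigl(\Sigma_{k+1/2} + 2\eta I
  + \bigl[\Sigma_{k+1/2}\,(\Sigma_{k+1/2} + 4\eta I)\bigr]^{1/2}\Bigr).
\end{align*}
```

In one dimension the last line reduces to $\sigma_{k+1}^2 = \tfrac{1}{2}\bigl(s + 2\eta + \sqrt{s(s + 4\eta)}\bigr)$ with $s = \sigma_{k+1/2}^2$, which shows that the backward step always inflates the variance slightly, counteracting the contraction from the forward step.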
Adding this algorithm to AdvancedVI has been made possible by the v0.5 update. I plan to add a couple (2~3) of new VI algorithms following this for v0.6.
Footnotes
[1] Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In International Conference on Machine Learning. PMLR.