Conversation

@Red-Portal
Member

@Red-Portal Red-Portal commented Oct 25, 2025

This adds the forward-backward Wasserstein Gaussian variational inference algorithm by Diao et al. [1]. This algorithm minimizes the KL divergence by running proximal stochastic gradient descent in the Bures-Wasserstein space. (The metric is the Wasserstein-2 distance and the gradient is the corresponding Bures-Wasserstein gradient.) Since this is a measure-space algorithm, it tends to converge faster than BBVI/ADVI as long as the step size is well-tuned.
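For readers unfamiliar with the method, here is a minimal one-dimensional sketch of a forward-backward step, written in plain Python rather than taken from this PR. It assumes a Gaussian target N(mu, sigma2), so that the expected gradient and Hessian of the potential are available in closed form; the algorithm in the paper uses stochastic estimates instead. The function names (`fb_gvi_step_1d`, `run_fb_gvi_1d`) are hypothetical and are not part of AdvancedVI.

```python
import math


def fb_gvi_step_1d(m, s, eta, mu, sigma2):
    """One forward-backward step for q = N(m, s) against the target N(mu, sigma2).

    Assumption of this sketch: the target is Gaussian, so the expectations of
    V'(x) and V''(x) under q are exact instead of Monte Carlo estimates.
    """
    # Forward step: gradient descent on the potential energy E_q[V],
    # where V(x) = (x - mu)^2 / (2 * sigma2).
    grad = (m - mu) / sigma2   # E_q[V'(X)]
    hess = 1.0 / sigma2        # E_q[V''(X)]
    m_new = m - eta * grad
    scale = 1.0 - eta * hess
    s_half = scale * s * scale

    # Backward step: proximal (JKO) step on the negative entropy,
    # which has a closed form for Gaussians; this is its scalar version.
    s_new = 0.5 * (s_half + 2.0 * eta + math.sqrt(s_half * (s_half + 4.0 * eta)))
    return m_new, s_new


def run_fb_gvi_1d(m0, s0, eta, mu, sigma2, n_steps):
    """Iterate the step above n_steps times from the initial mean m0 and variance s0."""
    m, s = m0, s0
    for _ in range(n_steps):
        m, s = fb_gvi_step_1d(m, s, eta, mu, sigma2)
    return m, s
```

With a step size below sigma2, iterating from, say, `run_fb_gvi_1d(0.0, 1.0, 0.1, 2.0, 0.5, 500)` should drive the mean toward `mu` and the variance toward `sigma2`, since the target's own parameters are a fixed point of the update.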

Adding this algorithm to AdvancedVI has been made possible by the v0.5 update. I plan to add two or three more VI algorithms after this one for v0.6.

Footnotes

  1. Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In International Conference on Machine Learning. PMLR.

Red-Portal and others added 3 commits October 25, 2025 10:05
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Contributor

@github-actions github-actions bot left a comment

Benchmark Results

| Benchmark suite | Current: 4b34875 | Previous: 85dcaa9 | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 4126219477.5 ns | 4031479469.5 ns | 1.02 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 1002333746 ns | 1118649204 ns | 0.90 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 923499003 ns | 1234572462 ns | 0.75 |
| normal/RepGradELBO + STL/fullrank/Zygote | 4098155528 ns | 3953335724.5 ns | 1.04 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1511496123.5 ns | 1598412120.5 ns | 0.95 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 981564964 ns | 1259515620.5 ns | 0.78 |
| normal/RepGradELBO/meanfield/Zygote | 2567299120.5 ns | 2821886790 ns | 0.91 |
| normal/RepGradELBO/meanfield/ReverseDiff | 663379206 ns | 771972557 ns | 0.86 |
| normal/RepGradELBO/meanfield/Mooncake | 819795409 ns | 1100687154 ns | 0.74 |
| normal/RepGradELBO/fullrank/Zygote | 2605812450.5 ns | 2850045117.5 ns | 0.91 |
| normal/RepGradELBO/fullrank/ReverseDiff | 826704652 ns | 943912138.5 ns | 0.88 |
| normal/RepGradELBO/fullrank/Mooncake | 862928027.5 ns | 1104866502 ns | 0.78 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 6113635200 ns | 5595282208 ns | 1.09 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 2012778478 ns | 2324059761 ns | 0.87 |
| normal + bijector/RepGradELBO + STL/meanfield/Mooncake | 4250639615 ns | 4123667267 ns | 1.03 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 6247668262 ns | 5540335701 ns | 1.13 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 2705739852 ns | 2915496029.5 ns | 0.93 |
| normal + bijector/RepGradELBO + STL/fullrank/Mooncake | 4420975128.5 ns | 4073600243.5 ns | 1.09 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 4501034455.5 ns | 4296709840 ns | 1.05 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 1624340725 ns | 1940028412 ns | 0.84 |
| normal + bijector/RepGradELBO/meanfield/Mooncake | 4115187135.5 ns | 3915092871 ns | 1.05 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 4579222710 ns | 4297139407 ns | 1.07 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 1855247349 ns | 2164520758 ns | 0.86 |
| normal + bijector/RepGradELBO/fullrank/Mooncake | 4223167692.5 ns | 3916935716 ns | 1.08 |

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Contributor

AdvancedVI.jl documentation for PR #210 is available at:
https://TuringLang.github.io/AdvancedVI.jl/previews/PR210/

Member

@mhauru mhauru left a comment

I only did a superficial read of the code, but didn't spot anything major, just a few style nits and a point about bumping the version in Project.toml.

HISTORY.md Outdated
@@ -1,3 +1,10 @@
# Release 0.6
Member

Project.toml probably needs a matching update of the version number.

You could release this as 0.5.1 if you want, since, as far as I understand, it doesn't break any existing code that uses AdvancedVI. It's not wrong to use 0.6 though if you think the change is major enough to warrant it.

Member Author

We do export more symbols though. So isn't it supposed to be a minor release?

Member

As I understand it, you don't have to bump the minor version. This thing says

> In particular, a package may set version = "0.2.4" when it has feature additions compared to 0.2.3 as long as it remains backward compatible with 0.2.0. See also The version field.

You could argue that exporting any new symbol is breaking because it could clash with existing symbols from outside your package. That argument has a point, but it's also a rule that is constantly broken, e.g. by Julia itself, which introduces new API in minor releases like 1.12.

By the semver standard you are in fact free to do whatever, because

> Major version zero (0.y.z) is for initial development. Anything MAY change at any time. The public API SHOULD NOT be considered stable.

I think the question is whether you want users of AdvancedVI who have specified compat in the usual, guarding-against-breakage way, with something like ^0.5, to get this version installed by default, or only if they manually bump the compat.

I'm happy if you want to release this as 0.6, just wanted to highlight that you have a choice.
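To make the compat point concrete, here is a small sketch in plain Python of how a caret specifier like ^0.5 maps to a version range, following the caret rules as documented for Julia's Pkg `[compat]` entries. It is an illustration only, not Pkg's implementation, and the helper names (`parse_version`, `caret_range`, `satisfies`) are made up for this example.

```python
def parse_version(text):
    """Parse 'X', 'X.Y' or 'X.Y.Z' into a list of one to three integers."""
    return [int(part) for part in text.split(".")]


def caret_range(spec):
    """Return (lower, upper) for a caret specifier such as '^0.5' or '0.5'.

    The range is half-open: a version v is accepted when lower <= v < upper.
    """
    parts = parse_version(spec.lstrip("^"))
    lower = tuple(parts + [0] * (3 - len(parts)))
    # The upper bound bumps the leftmost nonzero component that was given;
    # if every given component is zero, it bumps the last one given.
    idx = next((i for i, p in enumerate(parts) if p != 0), len(parts) - 1)
    upper = tuple(list(lower[:idx]) + [lower[idx] + 1] + [0] * (2 - idx))
    return lower, upper


def satisfies(version, spec):
    """Check whether a full 'X.Y.Z' version falls inside the caret range of spec."""
    parts = parse_version(version)
    v = tuple(parts + [0] * (3 - len(parts)))
    lower, upper = caret_range(spec)
    return lower <= v < upper
```

Under these rules, `satisfies("0.5.1", "^0.5")` holds while `satisfies("0.6.0", "^0.5")` does not, which is the distinction above: a 0.5.1 release reaches users with a ^0.5 compat entry automatically, while a 0.6.0 release requires them to bump it.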

Red-Portal and others added 4 commits October 29, 2025 18:58
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Red-Portal and others added 4 commits October 30, 2025 03:32
Co-authored-by: Markus Hauru <markus@mhauru.org>
@Red-Portal Red-Portal requested a review from mhauru October 30, 2025 07:38
Red-Portal and others added 2 commits October 30, 2025 03:38
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Member

@mhauru mhauru left a comment

I'm happy with this. It would be good if someone who understands the algorithm also approved, but I appreciate that such people are few, far between, and busy.

@Red-Portal
Member Author

Let me ping @yebai for good measure. Hong, are you happy if I go forward with this?

@Red-Portal Red-Portal merged commit 237e159 into main Nov 1, 2025
34 of 40 checks passed
@Red-Portal Red-Portal deleted the wasserstein_vi branch November 1, 2025 12:35
@Red-Portal Red-Portal added this to the v0.6.0 milestone Nov 22, 2025
3 participants