@Red-Portal (Member) commented Nov 21, 2025

Consider the case where we would like to approximate a constrained target distribution with density $\pi : \mathcal{X} \to \mathbb{R}_{> 0}$ using an unconstrained variational approximation with density $q : \mathbb{R}^d \to \mathbb{R}_{> 0}$. The canonical way to deal with this, popularized by the ADVI paper[^1], is to use a bijective transformation ("Bijectors") $b : \mathcal{X} \to \mathbb{R}^d$ such that $q$ is augmented into a density $q_{b^{-1}}$ on $\mathcal{X}$ as

$$q_{b^{-1}}(z) = q(b(z)) {\lvert \mathrm{J}_{b}(z) \rvert}$$

Then AdvancedVI needs to solve the problem

$$q_{b^{-1}}^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q_{b^{-1}}, \pi) .$$

Note, however, that the optimization is really over $q$, so AdvancedVI often needs access to the underlying $q$. I will refer to this as the "primal" scheme.
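To make the change of variables behind the primal scheme concrete, here is a minimal one-dimensional sketch in plain Julia. It assumes that $b = \log$ maps the constrained support $(0, \infty)$ to $\mathbb{R}$ and that $q$ is a Gaussian, so the constrained density is $q(b(z)) \lvert \mathrm{J}_b(z) \rvert$. All function names are hypothetical and are not part of AdvancedVI or Bijectors.

```julia
# Hypothetical 1-D illustration; none of these names come from AdvancedVI or Bijectors.

# Density of the unconstrained base q = Normal(mu, sigma) on the real line.
normal_pdf(eta, mu, sigma) = exp(-0.5 * ((eta - mu) / sigma)^2) / (sigma * sqrt(2pi))

# b maps the constrained support (0, Inf) to the real line; here b = log.
b(z) = log(z)
abs_jac_b(z) = 1 / z  # |d b(z) / dz| for z > 0

# Change of variables: density on (0, Inf) induced by pushing q through b^{-1}.
q_constrained(z, mu, sigma) = normal_pdf(b(z), mu, sigma) * abs_jac_b(z)

# Closed-form log-normal density, used only as a reference for comparison.
lognormal_pdf(z, mu, sigma) =
    exp(-(log(z) - mu)^2 / (2 * sigma^2)) / (z * sigma * sqrt(2pi))

mu, sigma = 0.3, 0.8
max_err = maximum(
    abs(q_constrained(z, mu, sigma) - lognormal_pdf(z, mu, sigma)) for z in 0.1:0.1:5.0
)
println("max abs difference from the log-normal density: ", max_err)
```

In this special case the augmented density is just the log-normal density, which is why the two functions agree up to floating-point rounding.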

Previously, this was done by giving special treatment to q <: Bijectors.TransformedDistribution through the Bijectors extension. In particular, the extension had to add specializations for many methods whose only job was to unwrap a TransformedDistribution before doing the actual work. This behavior is difficult to document and, as a result, was never fully explained in the documentation. Furthermore, every relevant method needs its own specialization in the Bijectors extension, which resulted in multiplicative complexity, especially for unit testing.

This, however, is unnecessary. Instead, there exists an equivalent "dual" problem that operates in unconstrained space by approximating the transformed posterior

$$\pi_b(\eta) = \pi(b^{-1}(\eta)) {\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} .$$

That is, we can solve the problem

$$q^* = \arg\min_{q \in \mathcal{Q}} \;\; \mathrm{D}(q, \pi_b)$$

and then post-process the output to retrieve $q_{b^{-1}}^*$.
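For the KL divergence, the equivalence can be checked with a change of variables. The following is a sketch under the assumption that $b : \mathcal{X} \to \mathbb{R}^d$ is a diffeomorphism, so that $q_{b^{-1}}(z) = q(b(z)) \lvert \mathrm{J}_{b}(z) \rvert$ and $\lvert \mathrm{J}_{b}(b^{-1}(\eta)) \rvert = 1 / \lvert \mathrm{J}_{b^{-1}}(\eta) \rvert$. Substituting $z = b^{-1}(\eta)$ gives

$$
\begin{aligned}
\mathrm{KL}(q_{b^{-1}}, \pi)
&= \int_{\mathcal{X}} q_{b^{-1}}(z) \log \frac{q_{b^{-1}}(z)}{\pi(z)} \, \mathrm{d}z \\
&= \int_{\mathbb{R}^d} \frac{q(\eta)}{\lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} \log \frac{q(\eta)}{\pi(b^{-1}(\eta)) \, \lvert \mathrm{J}_{b^{-1}}(\eta) \rvert} \, \lvert \mathrm{J}_{b^{-1}}(\eta) \rvert \, \mathrm{d}\eta \\
&= \int_{\mathbb{R}^d} q(\eta) \log \frac{q(\eta)}{\pi_b(\eta)} \, \mathrm{d}\eta
= \mathrm{KL}(q, \pi_b) ,
\end{aligned}
$$

so both problems share the same minimizer $q^*$, and $q_{b^{-1}}^*$ is recovered by pushing $q^*$ through $b^{-1}$. The argument is shown here only for $\mathrm{D} = \mathrm{KL}$.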

In this context, this PR removes the Bijectors extension to fix the problem. The rationales are as follows:

  • As mentioned above, AdvancedVI doesn't need to implement the primal scheme. In fact, the upcoming interface in Turing is planned to implement the dual scheme above.
  • The new algorithms (KLMinNaturalGradDescent, KLMinWassFwdBwd, and FisherMinBatchMatch, for example) do not work on constrained supports at all, so they can only be used via the dual scheme. The primal scheme implemented by KLMinRepGradDescent and friends is therefore redundant at this point and inconsistent with the rest of the algorithms.

Instead, a tutorial has been added to the documentation on how to use VI with constrained supports via the dual scheme.
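The dual scheme can be sketched in a few lines of plain Julia. This is only an illustration under stated assumptions, not the tutorial's code: the target is a hypothetical Exponential density on $(0, \infty)$, $b = \log$ maps it to $\mathbb{R}$, and the "fitted" Gaussian parameters `m` and `s` are placeholders standing in for the output of a VI algorithm. None of the names below are AdvancedVI or Bijectors APIs.

```julia
using Random

# Hypothetical constrained target on (0, Inf): Exponential with rate lambda.
lambda = 2.0
logpi(z) = log(lambda) - lambda * z

# b = log maps (0, Inf) to the real line, so b^{-1} = exp and
# log|J_{b^{-1}}(eta)| = eta.
binv(eta) = exp(eta)
logabsdetjac_binv(eta) = eta

# Transformed target pi_b, defined on the whole real line.
logpi_b(eta) = logpi(binv(eta)) + logabsdetjac_binv(eta)

# Sanity check: pi_b should still integrate to one (plain Riemann sum).
h = 1e-3
mass = sum(exp(logpi_b(eta)) for eta in -30.0:h:10.0) * h

# Stand-in for the fitted unconstrained approximation q* = Normal(m, s).
# In practice m and s would come from a VI algorithm run against logpi_b;
# the values here are placeholders.
m, s = -1.0, 1.0
rng = MersenneTwister(1)
eta_samples = m .+ s .* randn(rng, 1000)

# Post-processing: push unconstrained draws through b^{-1} to land in (0, Inf).
z_samples = binv.(eta_samples)

println("mass of pi_b ≈ ", mass)
println("all constrained samples positive: ", all(z_samples .> 0))
```

The point of the design is that the optimizer only ever sees `logpi_b`, an ordinary unconstrained log density, and the bijector appears again only in the final post-processing step.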

Footnotes

  1. Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of Machine Learning Research, 18(14), 1-45.


include("normallognormal.jl")
include("unconstrdist.jl")
struct Dist{D<:ContinuousMultivariateDistribution}
Member Author

The contents of unconstrdist.jl have been moved here.

@github-actions bot (Contributor) left a comment

Benchmark Results

| Benchmark suite | Current: 9a64db7 | Previous: 4d8d95e | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 2663625575.5 ns | 2602386322 ns | 1.02 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 610381229 ns | 609712269 ns | 1.00 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 247490722 ns | 245084270 ns | 1.01 |
| normal/RepGradELBO + STL/fullrank/Zygote | 2035966984 ns | 2041835270 ns | 1.00 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1152744904 ns | 1155642268 ns | 1.00 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 678304005.5 ns | 674570404.5 ns | 1.01 |
| normal/RepGradELBO/meanfield/Zygote | 1609768612 ns | 1567343454.5 ns | 1.03 |
| normal/RepGradELBO/meanfield/ReverseDiff | 300903420 ns | 304931634 ns | 0.99 |
| normal/RepGradELBO/meanfield/Mooncake | 172911707.5 ns | 171032782.5 ns | 1.01 |
| normal/RepGradELBO/fullrank/Zygote | 1161734600 ns | 1095442215 ns | 1.06 |
| normal/RepGradELBO/fullrank/ReverseDiff | 601496820 ns | 592074381 ns | 1.02 |
| normal/RepGradELBO/fullrank/Mooncake | 557909715 ns | 554962334.5 ns | 1.01 |

This comment was automatically generated by workflow using github-action-benchmark.

Red-Portal and others added 4 commits November 22, 2025 12:25
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

Labels

enhancement New feature or request