Add natural gradient variational inference algorithms #211

Conversation
```julia
    subsampling::Sub = nothing
end

"""
```
This portion has been moved to a separate file algorithms/gauss_expected_grad_hess.jl.
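(Judging purely from the filename, such a file would presumably compute Gaussian expectations of gradients and Hessians. For a Gaussian $q = \mathcal{N}(m, \Sigma)$, these are classically related through the Bonnet and Price identities; this is an inference from the name, not a description of the actual file contents:)

```latex
\nabla_m \, \mathbb{E}_{q}\left[f(z)\right] = \mathbb{E}_{q}\left[\nabla_z f(z)\right],
\qquad
\nabla_\Sigma \, \mathbb{E}_{q}\left[f(z)\right] = \tfrac{1}{2}\, \mathbb{E}_{q}\left[\nabla_z^2 f(z)\right].
```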
AdvancedVI.jl documentation for PR #211 is available at:
Benchmark Results

| Benchmark suite | Current: 49236af | Previous: f7f965a | Ratio |
|---|---|---|---|
| normal/RepGradELBO + STL/meanfield/Zygote | 3891957202 ns | 3981218216.5 ns | 0.98 |
| normal/RepGradELBO + STL/meanfield/ReverseDiff | 1135166344 ns | 1149188698 ns | 0.99 |
| normal/RepGradELBO + STL/meanfield/Mooncake | 1182439886 ns | 1191740601 ns | 0.99 |
| normal/RepGradELBO + STL/fullrank/Zygote | 3887266131.5 ns | 3944970212.5 ns | 0.99 |
| normal/RepGradELBO + STL/fullrank/ReverseDiff | 1638470865.5 ns | 1690940595 ns | 0.97 |
| normal/RepGradELBO + STL/fullrank/Mooncake | 1233128801 ns | 1237776514 ns | 1.00 |
| normal/RepGradELBO/meanfield/Zygote | 2757037062.5 ns | 2762774402 ns | 1.00 |
| normal/RepGradELBO/meanfield/ReverseDiff | 783215434 ns | 791093181 ns | 0.99 |
| normal/RepGradELBO/meanfield/Mooncake | 1070456904 ns | 1084249199 ns | 0.99 |
| normal/RepGradELBO/fullrank/Zygote | 2801742993.5 ns | 2822211650 ns | 0.99 |
| normal/RepGradELBO/fullrank/ReverseDiff | 969649118 ns | 991218524 ns | 0.98 |
| normal/RepGradELBO/fullrank/Mooncake | 1105595092 ns | 1087435619 ns | 1.02 |
| normal + bijector/RepGradELBO + STL/meanfield/Zygote | 5552914513 ns | 5523158001 ns | 1.01 |
| normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff | 2361965592 ns | 2456796214 ns | 0.96 |
| normal + bijector/RepGradELBO + STL/meanfield/Mooncake | 4005359864.5 ns | 3998343804 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/fullrank/Zygote | 5553126335 ns | 5543012671 ns | 1.00 |
| normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff | 3060106239.5 ns | 3121884284 ns | 0.98 |
| normal + bijector/RepGradELBO + STL/fullrank/Mooncake | 4124451867.5 ns | 4204003867.5 ns | 0.98 |
| normal + bijector/RepGradELBO/meanfield/Zygote | 4200797898.5 ns | 4283445922 ns | 0.98 |
| normal + bijector/RepGradELBO/meanfield/ReverseDiff | 2008641069 ns | 2093976830 ns | 0.96 |
| normal + bijector/RepGradELBO/meanfield/Mooncake | 3833597618.5 ns | 3895280534.5 ns | 0.98 |
| normal + bijector/RepGradELBO/fullrank/Zygote | 4279512628.5 ns | 4390742008.5 ns | 0.97 |
| normal + bijector/RepGradELBO/fullrank/ReverseDiff | 2272013350 ns | 2342918998 ns | 0.97 |
| normal + bijector/RepGradELBO/fullrank/Mooncake | 4002567444.5 ns | 4036735807.5 ns | 0.99 |
This comment was automatically generated by workflow using github-action-benchmark.
sunxd3 left a comment
not sure all of these are correct, but worth taking a look?
sunxd3 left a comment
tiny things, looking really good
Thanks, looks good! (Still wait till tests pass?)
@sunxd3 Thank you for thoroughly reviewing and spotting various mistakes!
This adds the following natural gradient VI algorithms:
Natural gradient VI (NGVI) is a family of algorithms that correspond to mirror descent under a Bregman divergence. Since this pseudo-metric is a divergence between distributions, NGVI can largely be viewed as a measure-space algorithm, and, empirically, it tends to converge faster than BBVI/ADVI. However, its update rules also involve quantities defined in terms of the variational parameters, so it is not a fully measure-space algorithm. As a result, design decisions about parametrizations and update rules lead to genuinely different implementations (hence the two algorithms in this PR). Furthermore, NGVI is restricted to (mixtures of) exponential-family variational families; this PR implements only the Gaussian variant. Another downside is that the update rules tend to involve operations that are both costly ($\mathrm{O}(d^3)$ for a $d$-dimensional target) and sensitive to numerical error.
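As a concrete sketch of where the cubic cost comes from: for a full-rank Gaussian family $q_t = \mathcal{N}(m_t, S_t^{-1})$ parametrized by its mean and precision, one common NGVI update (the Bayesian-learning-rule form; the symbols here are illustrative and not necessarily the PR's actual notation) reads

```latex
S_{t+1} = (1 - \rho)\, S_t + \rho\, \mathbb{E}_{q_t}\!\left[-\nabla_z^2 \log p(z)\right],
\qquad
m_{t+1} = m_t + \rho\, S_{t+1}^{-1}\, \mathbb{E}_{q_t}\!\left[\nabla_z \log p(z)\right],
```

where $\rho$ is the step size. The linear solve against $S_{t+1}$ (and any factorization needed to sample from $q_{t+1}$) is $\mathrm{O}(d^3)$, and the convex-combination update only preserves positive definiteness of $S$ for sufficiently small $\rho$, which is one motivation for the square-root and Cholesky-factor parametrizations cited in the footnotes.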
This addresses #1
Footnotes

1. Khan, M., & Lin, W. (2017). Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. AISTATS.
2. Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. Journal of Machine Learning Research, 24(281), 1–46.
3. Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization guarantees for square-root natural-gradient variational inference. TMLR.
4. Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can we remove the square-root in adaptive gradient methods? A second-order perspective. ICML.
5. Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. ICML.
6. Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. JRSS:B.