Merged
36 commits
1980f93
move gaussian expectation of grad and hess to its own file
Red-Portal Nov 7, 2025
c84b453
add square-root variational newton algorithm
Red-Portal Nov 7, 2025
48daaa0
apply formatter
Red-Portal Nov 7, 2025
3483e8d
add natural gradient descent (variational online Newton)
Red-Portal Nov 7, 2025
8267a98
update docstrings remove redundant comments
Red-Portal Nov 7, 2025
f3790c3
run formatter
Red-Portal Nov 7, 2025
3ba8401
update history
Red-Portal Nov 7, 2025
bca8f55
fix gauss expected grad hess, use in-place operations, add tests
Red-Portal Nov 11, 2025
82e9f15
fix always wrap `hess_buf` with a `Symmetric` (not `Hermitian`)
Red-Portal Nov 11, 2025
8fdecb1
Apply suggestion from @sunxd3
Red-Portal Nov 11, 2025
3ff2c0f
Apply suggestion from @sunxd3
Red-Portal Nov 11, 2025
45b9989
Apply suggestion from @github-actions[bot]
Red-Portal Nov 11, 2025
75e489d
fix bug in init of klminnaturalgraddescent
Red-Portal Nov 11, 2025
6f55a5c
remove unintended benchmark code
Red-Portal Nov 11, 2025
78d3559
update docs
Red-Portal Nov 11, 2025
020634d
fix relax Hermitian to Symmetric in NGVI ensure posdef
Red-Portal Nov 11, 2025
ccf6506
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into natur…
Red-Portal Nov 11, 2025
9070c82
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into natur…
Red-Portal Nov 12, 2025
47c8ed2
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into natur…
Red-Portal Nov 14, 2025
321fb92
fix gauss expected grad hess
Red-Portal Nov 14, 2025
f7f965a
fix callback argument in measure space algorithms
Red-Portal Nov 14, 2025
49236af
fix the positive definite preserving update rule in NGVI
Red-Portal Nov 14, 2025
5dba4c5
add the docs for the natural gradient descent algorithms
Red-Portal Nov 14, 2025
58d5cf5
fix docstrings for the measure-space algorithms
Red-Portal Nov 14, 2025
eec369c
Merge branch 'main' into natural_gradient_vi
Red-Portal Nov 14, 2025
d6a29d2
run formatter
Red-Portal Nov 14, 2025
45c3751
run formatter
Red-Portal Nov 14, 2025
3395851
run formatter
Red-Portal Nov 14, 2025
862fc60
run formatter
Red-Portal Nov 14, 2025
36a4be7
run formatter
Red-Portal Nov 14, 2025
76775c5
run formatter
Red-Portal Nov 14, 2025
d72f61c
run formatter
Red-Portal Nov 14, 2025
87de319
run formatter
Red-Portal Nov 14, 2025
1c6226a
run formatter
Red-Portal Nov 14, 2025
248112b
run formatter
Red-Portal Nov 14, 2025
51a0b3f
run formatter
Red-Portal Nov 14, 2025
3 changes: 3 additions & 0 deletions docs/make.jl
@@ -27,6 +27,9 @@ makedocs(;
"`KLMinRepGradProxDescent`" => "klminrepgradproxdescent.md",
"`KLMinScoreGradDescent`" => "klminscoregraddescent.md",
"`KLMinWassFwdBwd`" => "klminwassfwdbwd.md",
"`KLMinNaturalGradDescent`" => "klminnaturalgraddescent.md",
"`KLMinSqrtNaturalGradDescent`" => "klminsqrtnaturalgraddescent.md",
"`KLMinSqrtNaturalGradDescent`" => "klminsqrtnaturalgraddescent.md",
],
"Variational Families" => "families.md",
"Optimization" => "optimization.md",
77 changes: 77 additions & 0 deletions docs/src/klminnaturalgraddescent.md
@@ -0,0 +1,77 @@
# [`KLMinNaturalGradDescent`](@id klminnaturalgraddescent)

## Description

This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent.
`KLMinNaturalGradDescent` is a specific implementation of natural gradient variational inference (NGVI), also known as variational online Newton[^KR2023].
For nearly-Gaussian targets, NGVI tends to converge very quickly.
If the `ensure_posdef` option is set to `true` (the default), the update rule of [^LSK2020] is used, which guarantees that the updated precision matrix remains positive definite.
Since `KLMinNaturalGradDescent` is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (`FullRankGaussian`), which makes the updates tractable.

```@docs
KLMinNaturalGradDescent
```

The associated objective value, which is the ELBO, can be estimated through the following:

```@docs; canonical=false
estimate_objective(
::Random.AbstractRNG,
::KLMinNaturalGradDescent,
::MvLocationScale,
::Any;
::Int,
)
```
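
For reference, below is a minimal usage sketch.
It assumes the generic `AdvancedVI.optimize` entry point and a hypothetical `LogDensityProblems`-compatible target (`IsoGaussian`) with hand-coded gradients; the exact positional/keyword arguments and return values may differ slightly from the released interface, so consult the general usage documentation for details.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# Hypothetical target: a standard multivariate Gaussian with a hand-coded gradient.
struct IsoGaussian
    d::Int
end
LogDensityProblems.dimension(prob::IsoGaussian) = prob.d
LogDensityProblems.capabilities(::Type{IsoGaussian}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity(prob::IsoGaussian, x) = -sum(abs2, x)/2
LogDensityProblems.logdensity_and_gradient(prob::IsoGaussian, x) = (-sum(abs2, x)/2, -x)

d    = 5
prob = IsoGaussian(d)
q0   = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
alg  = KLMinNaturalGradDescent(; stepsize=0.1)

# The call below follows the pattern used elsewhere in these docs;
# it iterates the natural gradient update starting from `q0`.
result = AdvancedVI.optimize(alg, 1_000, prob, q0)
```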

[^KR2023]: Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. *Journal of Machine Learning Research*, 24(281), 1-46.
[^LSK2020]: Lin, W., Schmidt, M., & Khan, M. E. (2020). Handling the positive-definite constraint in the Bayesian learning rule. In *International Conference on Machine Learning*. PMLR.
## [Methodology](@id klminnaturalgraddescent_method)

This algorithm aims to solve the problem

```math
\mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
```

where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running natural gradient descent over the variational parameters.
That is, for each $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.

Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
Instead, the standard approach, equivalent to ELBO maximization, is to minimize a surrogate objective, the *negative evidence lower bound*[^JGJS1999]

```math
\mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right),
```

which is equivalent to the KL up to an additive constant (the evidence).

Suppose we had access to the exact gradients $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$.
NGVI attempts to minimize $\mathcal{L}$ via natural gradient descent, which corresponds to iterating the mirror descent update

```math
\lambda_{t+1} = \argmin_{\lambda \in \Lambda} {\langle \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right), \lambda - \lambda_t \rangle} + \frac{1}{2 \gamma_t} \mathrm{KL}\left(q_{\lambda}, q_{\lambda_t}\right) .
```

This turns out to be equivalent to the update

```math
\lambda_{t+1} = \lambda_{t} - \gamma_t {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t}) ,
```

where $F(\lambda_t)$ is the Fisher information matrix of $q_{\lambda}$.
That is, natural gradient descent can be viewed as gradient descent with an iterate-dependent preconditioning.
Furthermore, ${F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t})$ is referred to as the *natural gradient* of the KL divergence[^A1998], hence the name natural gradient variational inference.
Also note that the gradient is taken with respect to the parameters of $q_{\lambda}$.
Therefore, NGVI is parametrization dependent: for the same variational family, different parametrizations will result in different behavior.
However, the pseudo-metric $\mathrm{KL}\left(q_{\lambda}, q_{\lambda_t}\right)$ is defined over measures.
As a result, NGVI tends to behave like a measure-space algorithm, although, strictly speaking, it is not a fully measure-space algorithm.

In practice, we only have access to unbiased estimates of $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$.
Nevertheless, natural gradient descent/mirror descent updates based on these stochastic estimates have been derived for certain variational families, for instance, Gaussian variational families[^KR2023] and mixtures of exponential-family distributions[^LKS2019].
As of now, we only implement the Gaussian version.
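
For concreteness, writing the Gaussian approximation in terms of its mean $m_t$ and precision matrix $S_t$, the resulting update takes roughly the following form (a sketch following the variational online Newton update of [^KR2023]; the rule actually implemented, in particular with `ensure_posdef = true`, adds a correction term to the precision update):

```math
\begin{align*}
S_{t+1} &= (1 - \gamma_t) S_t + \gamma_t \, \mathbb{E}_{q_{\lambda_t}}\left[ -\nabla^2 \log \pi \right] \\
m_{t+1} &= m_t + \gamma_t \, S_{t+1}^{-1} \, \mathbb{E}_{q_{\lambda_t}}\left[ \nabla \log \pi \right] ,
\end{align*}
```

where the expectations are replaced with Monte Carlo estimates in practice.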

[^LKS2019]: Lin, W., Khan, M. E., & Schmidt, M. (2019). Fast and simple natural-gradient variational inference with mixture of exponential-family approximations. In *International Conference on Machine Learning*. PMLR.
[^A1998]: Amari, S. I. (1998). Natural gradient works efficiently in learning. *Neural computation*, 10(2), 251-276.
[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233.
89 changes: 89 additions & 0 deletions docs/src/klminsqrtnaturalgraddescent.md
@@ -0,0 +1,89 @@
# [`KLMinSqrtNaturalGradDescent`](@id klminsqrtnaturalgraddescent)

## Description

This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent.
`KLMinSqrtNaturalGradDescent` is a specific implementation of natural gradient variational inference (NGVI), also known as square-root variational Newton[^KMKL2025][^LDEBTM2024][^LDLNKS2023][^T2025].
This algorithm operates on the square-root (or Cholesky) parameterization of the covariance matrix.
This contrasts with [`KLMinNaturalGradDescent`](@ref klminnaturalgraddescent), which operates in the precision matrix parameterization, requiring a matrix inverse at each step.
As a result, each step of `KLMinSqrtNaturalGradDescent` should be comparatively cheaper.
Since `KLMinSqrtNaturalGradDescent` is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (`FullRankGaussian`), which makes the updates tractable.

```@docs
KLMinSqrtNaturalGradDescent
```

The associated objective value, which is the ELBO, can be estimated through the following:

```@docs; canonical=false
estimate_objective(
::Random.AbstractRNG,
::KLMinSqrtNaturalGradDescent,
::MvLocationScale,
::Any;
::Int,
)
```
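
For reference, below is a minimal construction sketch; the surrounding optimization loop mirrors the example on the [`KLMinNaturalGradDescent`](@ref klminnaturalgraddescent) page, and the `stepsize` value shown is an arbitrary illustrative choice.

```julia
using AdvancedVI, LinearAlgebra

d   = 5
# Square-root (Cholesky) parameterization: the scale is a lower-triangular factor.
q0  = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
alg = KLMinSqrtNaturalGradDescent(; stepsize=0.1)
# `AdvancedVI.optimize(alg, max_iter, prob, q0)` can then be called as usual.
```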

[^KMKL2025]: Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. *Transactions of Machine Learning Research*.
[^LDEBTM2024]: Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In *International Conference on Machine Learning*.
[^LDLNKS2023]: Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In *International Conference on Machine Learning*.
[^T2025]: Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. *Journal of the Royal Statistical Society: Series B.*
## [Methodology](@id klminsqrtnaturalgraddescent_method)

This algorithm aims to solve the problem

```math
\mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
```

where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running natural gradient descent over the variational parameters.
That is, for each $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.

Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
Instead, the standard approach, equivalent to ELBO maximization, is to minimize a surrogate objective, the *negative evidence lower bound*[^JGJS1999]

```math
\mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right),
```

which is equivalent to the KL up to an additive constant (the evidence).

While `KLMinSqrtNaturalGradDescent` is closely related to natural gradient variational inference, it can be derived in a variety of different ways.
In fact, the update rule has been concurrently developed by several research groups[^KMKL2025][^LDEBTM2024][^LDLNKS2023][^T2025].
Here, we will present the derivation by Kumar *et al.* [^KMKL2025].
Consider the ideal natural gradient descent algorithm discussed [here](@ref klminnaturalgraddescent_method).
This can be viewed as a discretization of the continuous-time dynamics given by the differential equation

```math
\dot{\lambda}_t
=
- {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right) .
```

This is also known as the *natural gradient flow*.
Notice that the flow is over the parameters $\lambda_t$.
Therefore, the natural gradient flow depends on the way we parametrize $q_{\lambda}$.
For Gaussian variational families, if we specifically choose the *square-root* (or Cholesky) parametrization such that $q_{\lambda_t} = \mathrm{Normal}(m_t, C_t C_t^{\top})$, the flow of $\lambda_t = (m_t, C_t)$ is given as

```math
\begin{align*}
\dot{m}_t &= C_t C_t^{\top} \mathbb{E}_{q_{\lambda_t}} \left[ \nabla \log \pi \right]
\\
\dot{C}_t &= C_t M\left( \mathrm{I}_d + C_t^{\top} \mathbb{E}_{q_{\lambda_t}}\left[ \nabla^2 \log \pi \right] C_t \right) ,
\end{align*}
```

where $M$ is a $\mathrm{tril}$-like function defined as

```math
{[ M(A) ]}_{ij} = \begin{cases}
A_{ij} & \text{if $i > j$} \\
\frac{1}{2} A_{ii} & \text{if $i = j$} \\
0 & \text{if $i < j$} .
\end{cases}
```
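
For example, for a $2 \times 2$ matrix,

```math
M\left( \begin{bmatrix} a & b \\ c & d \end{bmatrix} \right)
= \begin{bmatrix} a/2 & 0 \\ c & d/2 \end{bmatrix} .
```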

`KLMinSqrtNaturalGradDescent` corresponds to the forward Euler discretization of this flow.
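
To make the discretization concrete, below is a sketch of a single forward Euler step in Julia.
It is illustrative only: the expectations are replaced by Monte Carlo estimates over samples from the current approximation, `grad_logpi` and `hess_logpi` are hypothetical callables returning $\nabla \log \pi$ and $\nabla^2 \log \pi$, and the actual implementation differs in details (e.g., in-place updates and how the Hessian expectation is estimated).

```julia
using LinearAlgebra, Random

# Lower-triangular part of A with the diagonal halved (the function M above).
M_tril(A) = LowerTriangular(A) - Diagonal(A) / 2

# One forward Euler step of the square-root natural gradient flow.
function sqrt_ngd_step(m, C, grad_logpi, hess_logpi; stepsize=0.1, n_samples=8, rng=Random.default_rng())
    d  = length(m)
    xs = [m + C * randn(rng, d) for _ in 1:n_samples]     # samples from Normal(m, C * C')
    Eg = sum(grad_logpi.(xs)) / n_samples                  # ≈ E[∇ log π]
    EH = sum(hess_logpi.(xs)) / n_samples                  # ≈ E[∇² log π]
    m_new = m + stepsize * (C * (C' * Eg))
    C_new = C + stepsize * (C * M_tril(I(d) + C' * EH * C))
    return m_new, LowerTriangular(C_new)
end
```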

[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233.
8 changes: 4 additions & 4 deletions src/algorithms/klminnaturalgraddescent.jl
@@ -9,9 +9,9 @@ This algorithm can be viewed as an instantiation of mirror descent, where the Br
If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
This, however, involves an additional set of matrix-matrix system solves that could be costly.

The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
If the target has only first-order capability, we use Stein's identity.
This algorithm requires second-order information about the target.
If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used.
Otherwise, if the target has only first-order capability, it will use only gradients, but this will probably result in slower convergence and less robust behavior.

# (Keyword) Arguments
- `stepsize::Float64`: Step size.
@@ -38,7 +38,7 @@ The keyword arguments are as follows:

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target distribution has unconstrained support.
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
8 changes: 4 additions & 4 deletions src/algorithms/klminsqrtnaturalgraddescent.jl
@@ -5,9 +5,9 @@

KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025].

The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
If the target has only first-order capability, we use Stein's identity.
This algorithm requires second-order information about the target.
If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used.
Otherwise, if the target has only first-order capability, it will use only gradients, but this will probably result in slower convergence and less robust behavior.

# (Keyword) Arguments
- `stepsize::Float64`: Step size.
@@ -33,7 +33,7 @@ The keyword arguments are as follows:

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target distribution has unconstrained support.
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinSqrtNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
8 changes: 4 additions & 4 deletions src/algorithms/klminwassfwdbwd.jl
@@ -5,9 +5,9 @@

KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023].

The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
If the target has only first-order capability, we use Stein's identity.
This algorithm requires second-order information about the target.
If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used.
Otherwise, if the target has only first-order capability, it will use only gradients, but this will probably result in slower convergence and less robust behavior.

# (Keyword) Arguments
- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`)
@@ -33,7 +33,7 @@ The keyword arguments are as follows:

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target distribution has unconstrained support.
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <: