diff --git a/docs/make.jl b/docs/make.jl
index 2cacf5fcb..7cfafcea0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -27,6 +27,8 @@ makedocs(;
         "`KLMinRepGradProxDescent`" => "klminrepgradproxdescent.md",
         "`KLMinScoreGradDescent`" => "klminscoregraddescent.md",
         "`KLMinWassFwdBwd`" => "klminwassfwdbwd.md",
+        "`KLMinNaturalGradDescent`" => "klminnaturalgraddescent.md",
+        "`KLMinSqrtNaturalGradDescent`" => "klminsqrtnaturalgraddescent.md",
     ],
     "Variational Families" => "families.md",
     "Optimization" => "optimization.md",
diff --git a/docs/src/klminnaturalgraddescent.md b/docs/src/klminnaturalgraddescent.md
new file mode 100644
index 000000000..011e3b804
--- /dev/null
+++ b/docs/src/klminnaturalgraddescent.md
@@ -0,0 +1,77 @@
+# [`KLMinNaturalGradDescent`](@id klminnaturalgraddescent)
+
+## Description
+
+This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent.
+`KLMinNaturalGradDescent` is a specific implementation of natural gradient variational inference (NGVI), also known as variational online Newton[^KR2023].
+For nearly-Gaussian targets, NGVI tends to converge very quickly.
+If the `ensure_posdef` option is set to `true` (the default), the update rule of [^LSK2020] is used, which guarantees that the updated precision matrix is always positive definite.
+Since `KLMinNaturalGradDescent` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`), which make the updates tractable.
+
+```@docs
+KLMinNaturalGradDescent
+```
+
+The associated objective value, which is the ELBO, can be estimated through the following:
+
+```@docs; canonical=false
+estimate_objective(
+    ::Random.AbstractRNG,
+    ::KLMinNaturalGradDescent,
+    ::MvLocationScale,
+    ::Any;
+    ::Int,
+)
+```
+
+[^KR2023]: Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. *Journal of Machine Learning Research*, 24(281), 1-46.
+[^LSK2020]: Lin, W., Schmidt, M., & Khan, M. E. (2020). Handling the positive-definite constraint in the Bayesian learning rule. In *International Conference on Machine Learning*. PMLR.
+
+## [Methodology](@id klminnaturalgraddescent_method)
+
+This algorithm aims to solve the problem
+
+```math
+    \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
+```
+
+where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running natural gradient descent over the space of variational parameters.
+That is, for each $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.
+
+Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
+Instead, the ELBO maximization strategy minimizes a surrogate objective, the *negative evidence lower bound*[^JGJS1999]
+
+```math
+    \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q}\left[-\log \pi\left(\theta\right)\right] - \mathbb{H}\left(q\right),
+```
+
+which is equivalent to the KL up to an additive constant (the evidence).
+
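+To make the surrogate concrete, the negative ELBO can be estimated by plain Monte Carlo.
+The snippet below is only an illustrative sketch and not part of the package's API; `logπ` is a hypothetical stand-in for an unnormalized target log-density.
+
+```julia
+using Distributions, LinearAlgebra, Statistics
+
+# Monte Carlo estimate of L(q) = E_q[-log π(θ)] - H(q) for a full-rank Gaussian q.
+function negative_elbo(logπ, q::MvNormal; n_samples::Int=64)
+    θ = rand(q, n_samples)                                    # draw θ ~ q, one column per sample
+    energy = mean(-logπ(view(θ, :, i)) for i in 1:n_samples)  # ≈ E_q[-log π(θ)]
+    return energy - entropy(q)                                # subtract the entropy H(q)
+end
+
+# Sanity check: for a standard Gaussian target and q = π, the estimate is ≈ 0.
+logπ(θ) = -0.5 * (sum(abs2, θ) + length(θ) * log(2π))
+q = MvNormal(zeros(2), Matrix(1.0I, 2, 2))
+negative_elbo(logπ, q; n_samples=10_000)
+```
+
+Up to the sign convention, this is the quantity reported by `estimate_objective` documented above.
+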
+Suppose we had access to the exact gradients $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$.
+NGVI attempts to minimize $\mathcal{L}$ via natural gradient descent, which corresponds to iterating the mirror descent update
+
+```math
+\lambda_{t+1} = \argmin_{\lambda \in \Lambda} {\langle \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right), \lambda - \lambda_t \rangle} + \frac{1}{2 \gamma_t} \mathrm{KL}\left(q_{\lambda}, q_{\lambda_t}\right) .
+```
+
+This turns out to be equivalent to the update
+
+```math
+\lambda_{t+1} = \lambda_{t} - \gamma_t {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t}) ,
+```
+
+where $F(\lambda_t)$ is the Fisher information matrix of $q_{\lambda}$ evaluated at $\lambda_t$.
+That is, natural gradient descent can be viewed as gradient descent with iterate-dependent preconditioning.
+Furthermore, ${F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t})$ is referred to as the *natural gradient* of the KL divergence[^A1998], hence the name natural gradient variational inference.
+Also note that the gradient is taken with respect to the parameters of $q_{\lambda}$.
+Therefore, NGVI is parametrization-dependent: for the same variational family, different parametrizations will result in different behavior.
+However, the pseudo-metric $\mathrm{KL}\left(q_{\lambda}, q_{\lambda_t}\right)$ is defined over measures.
+Therefore, NGVI tends to behave like a measure-space algorithm, even though, technically speaking, it is not a fully measure-space algorithm.
+
+In practice, we only have access to unbiased estimates of $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$, not the exact gradients.
+Nevertheless, natural gradient descent/mirror descent updates based on these stochastic estimates have been derived for some variational families, for instance, Gaussian variational families[^KR2023] and mixtures of exponential families[^LKS2019].
+As of now, we only implement the Gaussian version.
+
+[^LKS2019]: Lin, W., Khan, M. E., & Schmidt, M. (2019). Fast and simple natural-gradient variational inference with mixture of exponential-family approximations. In *International Conference on Machine Learning*. PMLR.
+[^A1998]: Amari, S. I. (1998). Natural gradient works efficiently in learning. *Neural Computation*, 10(2), 251-276.
+[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. *Machine Learning*, 37, 183-233.
diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md
new file mode 100644
index 000000000..1d5c45b7f
--- /dev/null
+++ b/docs/src/klminsqrtnaturalgraddescent.md
@@ -0,0 +1,89 @@
+# [`KLMinSqrtNaturalGradDescent`](@id klminsqrtnaturalgraddescent)
+
+## Description
+
+This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent.
+`KLMinSqrtNaturalGradDescent` is a specific implementation of natural gradient variational inference (NGVI), also known as square-root variational Newton[^KMKL2025][^LDEBTM2024][^LDLNKS2023][^T2025].
+This algorithm operates under the square-root (Cholesky factor) parameterization of the covariance matrix.
+This contrasts with [`KLMinNaturalGradDescent`](@ref klminnaturalgraddescent), which operates in the precision matrix parameterization and therefore requires a matrix inverse at each step.
+As a result, each step of `KLMinSqrtNaturalGradDescent` should be relatively cheaper.
+Since `KLMinSqrtNaturalGradDescent` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`), which make the updates tractable.
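+
+A rough usage sketch is shown below.
+Only the constructor name and the `stepsize` keyword are taken from the docstring that follows; the `FullRankGaussian` arguments and the `optimize` call are assumptions about the surrounding interface and may differ from the actual API.
+
+```julia
+using AdvancedVI, LinearAlgebra
+
+d = 5                                               # dimensionality of a hypothetical target `prob`
+alg = KLMinSqrtNaturalGradDescent(; stepsize=1e-3)  # `stepsize` as documented below
+
+# Assumed interface (hypothetical signatures): initialize a full-rank Gaussian with a
+# lower-triangular scale, then hand the algorithm, target, and initialization to `optimize`.
+# q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix(1.0I, d, d)))
+# q, info, state = AdvancedVI.optimize(alg, 10_000, prob, q0)
+```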
+
+```@docs
+KLMinSqrtNaturalGradDescent
+```
+
+The associated objective value, which is the ELBO, can be estimated through the following:
+
+```@docs; canonical=false
+estimate_objective(
+    ::Random.AbstractRNG,
+    ::KLMinSqrtNaturalGradDescent,
+    ::MvLocationScale,
+    ::Any;
+    ::Int,
+)
+```
+
+[^KMKL2025]: Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. *Transactions on Machine Learning Research*.
+[^LDEBTM2024]: Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In *International Conference on Machine Learning*.
+[^LDLNKS2023]: Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In *International Conference on Machine Learning*.
+[^T2025]: Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. *Journal of the Royal Statistical Society: Series B*.
+
+## [Methodology](@id klminsqrtnaturalgraddescent_method)
+
+This algorithm aims to solve the problem
+
+```math
+    \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
+```
+
+where $\mathcal{Q}$ is some family of distributions, often called the variational family, by discretizing the natural gradient flow over the space of variational parameters.
+That is, for each $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.
+
+Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
+Instead, the ELBO maximization strategy minimizes a surrogate objective, the *negative evidence lower bound*[^JGJS1999]
+
+```math
+    \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q}\left[-\log \pi\left(\theta\right)\right] - \mathbb{H}\left(q\right),
+```
+
+which is equivalent to the KL up to an additive constant (the evidence).
+
+While `KLMinSqrtNaturalGradDescent` is closely related to natural gradient variational inference, it can be derived in a variety of different ways.
+In fact, the update rule has been developed concurrently by several research groups[^KMKL2025][^LDEBTM2024][^LDLNKS2023][^T2025].
+Here, we will present the derivation by Kumar *et al.*[^KMKL2025].
+Consider the ideal natural gradient descent algorithm discussed [here](@ref klminnaturalgraddescent_method).
+It can be viewed as a discretization of the continuous-time dynamics given by the differential equation
+
+```math
+\dot{\lambda}_t
+=
+- {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right) .
+```
+
+This is also known as the *natural gradient flow*.
+Notice that the flow is over the parameters $\lambda_t$.
+Therefore, the natural gradient flow depends on the way we parametrize $q_{\lambda}$.
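+
+As a quick illustration of this discretization, a single explicit (forward) Euler step of the flow in a generic parameterization is sketched below; `fisher` and `grad_L` are hypothetical stand-ins for $F$ and $\nabla_{\lambda} \mathcal{L}$, not functions provided by the package.
+
+```julia
+# One forward-Euler step of the natural gradient flow:
+# λ_{t+1} = λ_t - γ F(λ_t)⁻¹ ∇L(λ_t).
+natural_gradient_step(λ, γ, fisher, grad_L) = λ - γ * (fisher(λ) \ grad_L(λ))
+
+# Toy usage with L(λ) = ‖λ‖² / 2 and an identity Fisher matrix:
+λ = [1.0, -2.0]
+λ = natural_gradient_step(λ, 0.1, λ -> [1.0 0.0; 0.0 1.0], λ -> λ)
+```
+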
+For Gaussian variational families, if we specifically choose the *square-root* (or Cholesky) parametrization such that $q_{\lambda_t} = \mathrm{Normal}(m_t, C_t C_t^{\top})$, the flow of $\lambda_t = (m_t, C_t)$ is given as
+
+```math
+\begin{align*}
+\dot{m}_t &= C_t C_t^{\top} \mathbb{E}_{q_{\lambda_t}} \left[ \nabla \log \pi \right]
+\\
+\dot{C}_t &= C_t M\left( \mathrm{I}_d + C_t^{\top} \mathbb{E}_{q_{\lambda_t}}\left[ \nabla^2 \log \pi \right] C_t \right) ,
+\end{align*}
+```
+
+where $M$ is a $\mathrm{tril}$-like function defined as
+
+```math
+{[ M(A) ]}_{ij} = \begin{cases}
+ 0 & \text{if $i > j$} \\
+ \frac{1}{2} A_{ii} & \text{if $i = j$} \\
+ A_{ij} & \text{if $i < j$} .
+\end{cases}
+```
+
+`KLMinSqrtNaturalGradDescent` corresponds to the forward Euler discretization of this flow.
+
+[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. *Machine Learning*, 37, 183-233.
diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl
index d465b6037..e47be81cf 100644
--- a/src/algorithms/klminnaturalgraddescent.jl
+++ b/src/algorithms/klminnaturalgraddescent.jl
@@ -9,9 +9,9 @@ This algorithm can be viewed as an instantiation of mirror descent, where the Br
 If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
 This, however, involves an additional set of matrix-matrix system solves that could be costly.
 
-The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
-If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
-If the target has only first-order capability, we use Stein's identity.
+This algorithm requires second-order information about the target.
+If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians of the target log-density are used.
+Otherwise, if the target has only first-order capability, only gradients are used, which will probably result in slower convergence and less robust behavior.
 
 # (Keyword) Arguments
 - `stepsize::Float64`: Step size.
@@ -38,7 +38,7 @@ The keyword arguments are as follows:
 
 # Requirements
 - The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
-- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
+- The target distribution has unconstrained support.
 - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
""" @kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index a26af011c..a8965baaf 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -5,9 +5,9 @@ KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025]. -The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. -If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. -If the target has only first-order capability, we use Stein's identity. +This algorithm requires second-order information about the target. +If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used. +Otherwise, if the target has only first-order capability, it will use only gradients but this will porbably result in slower convergence and less robust behavior. # (Keyword) Arguments - `stepsize::Float64`: Step size. @@ -33,7 +33,7 @@ The keyword arguments are as follows: # Requirements - The variational family is [`FullRankGaussian`](@ref FullRankGaussian). -- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). +- The target distribution has unconstrained support. - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. """ @kwdef struct KLMinSqrtNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index e04184fb4..570321ac6 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -5,9 +5,9 @@ KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023]. -The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. -If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. -If the target has only first-order capability, we use Stein's identity. +This algorithm requires second-order information about the target. +If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used. +Otherwise, if the target has only first-order capability, it will use only gradients but this will porbably result in slower convergence and less robust behavior. # (Keyword) Arguments - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. 
@@ -33,7 +33,7 @@ The keyword arguments are as follows:
 
 # Requirements
 - The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
-- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
+- The target distribution has unconstrained support.
 - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
 """
 @kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <: