From 1980f93ccc18d2a3d935104510169d8727eb6ee7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 02:44:48 -0500 Subject: [PATCH 01/32] move gaussian expectation of grad and hess to its own file --- src/AdvancedVI.jl | 3 +- src/algorithms/gauss_expected_grad_hess.jl | 55 ++++++++++++++++++++++ src/algorithms/klminwassfwdbwd.jl | 54 --------------------- 3 files changed, 57 insertions(+), 55 deletions(-) create mode 100644 src/algorithms/gauss_expected_grad_hess.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 2b82b7b79..f62dbb200 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -352,8 +352,9 @@ include("algorithms/common.jl") export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI -# Other Algorithms +# Natural and Wasserstein gradient descent algorithms +include("algorithms/gauss_expected_grad_hess.jl") include("algorithms/klminwassfwdbwd.jl") export KLMinWassFwdBwd diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl new file mode 100644 index 000000000..bcd124c8b --- /dev/null +++ b/src/algorithms/gauss_expected_grad_hess.jl @@ -0,0 +1,55 @@ + +""" + gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) + +Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. +- `n_samples::Int`: Number of samples used for estimation. +- `grad_buf::AbstractVector`: Buffer for the gradient estimate. +- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate. +- `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation. +""" +function gaussian_expectation_gradient_and_hessian!( + rng::Random.AbstractRNG, + q::MvLocationScale{<:LowerTriangular,<:Normal,L}, + n_samples::Int, + grad_buf::AbstractVector{T}, + hess_buf::AbstractMatrix{T}, + prob, +) where {T<:Real,L} + logπ_avg = zero(T) + fill!(grad_buf, zero(T)) + fill!(hess_buf, zero(T)) + + if LogDensityProblems.capabilities(typeof(prob)) ≤ + LogDensityProblems.LogDensityOrder{1}() + # Use Stein's identity + d = LogDensityProblems.dimension(prob) + u = randn(rng, T, d, n_samples) + z = q.scale*u .+ q.location + for b in 1:n_samples + zb, ub = view(z, :, b), view(u, :, b) + logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) + logπ_avg += logπ/n_samples + grad_buf += ∇logπ/n_samples + hess_buf += ub*(∇logπ/n_samples)' + end + return logπ_avg, grad_buf, hess_buf + else + # Use sample average of the Hessian. 
+ z = rand(rng, q, n_samples) + for b in 1:n_samples + zb = view(z, :, b) + logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian( + prob, zb + ) + logπ_avg += logπ/n_samples + grad_buf += ∇logπ/n_samples + hess_buf += ∇2logπ/n_samples + end + return logπ_avg, grad_buf, hess_buf + end +end diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index f834b539a..7f6c27f20 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -41,60 +41,6 @@ The keyword arguments are as follows: subsampling::Sub = nothing end -""" - gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) - -Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. -- `n_samples::Int`: Number of samples used for estimation. -- `grad_buf::AbstractVector`: Buffer for the gradient estimate. -- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate. -- `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation. -""" -function gaussian_expectation_gradient_and_hessian!( - rng::Random.AbstractRNG, - q::MvLocationScale{<:LowerTriangular,<:Normal,L}, - n_samples::Int, - grad_buf::AbstractVector{T}, - hess_buf::AbstractMatrix{T}, - prob, -) where {T<:Real,L} - logπ_avg = zero(T) - fill!(grad_buf, zero(T)) - fill!(hess_buf, zero(T)) - - if LogDensityProblems.capabilities(typeof(prob)) ≤ - LogDensityProblems.LogDensityOrder{1}() - # Use Stein's identity - d = LogDensityProblems.dimension(prob) - u = randn(rng, T, d, n_samples) - z = q.scale*u .+ q.location - for b in 1:n_samples - zb, ub = view(z, :, b), view(u, :, b) - logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) - logπ_avg += logπ/n_samples - grad_buf += ∇logπ/n_samples - hess_buf += ub*(∇logπ/n_samples)' - end - return logπ_avg, grad_buf, hess_buf - else - # Use sample average of the Hessian. 
- z = rand(rng, q, n_samples) - for b in 1:n_samples - zb = view(z, :, b) - logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian( - prob, zb - ) - logπ_avg += logπ/n_samples - grad_buf += ∇logπ/n_samples - hess_buf += ∇2logπ/n_samples - end - return logπ_avg, grad_buf, hess_buf - end -end struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf} q::Q From c84b453b0379c92109070ab0c4565c43d0510117 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 02:45:44 -0500 Subject: [PATCH 02/32] add square-root variational newton algorithm --- src/AdvancedVI.jl | 3 +- src/algorithms/klminsqrtnaturalgraddescent.jl | 157 +++++++++++++++++ .../algorithms/klminsqrtnaturalgraddescent.jl | 158 ++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 318 insertions(+), 1 deletion(-) create mode 100644 src/algorithms/klminsqrtnaturalgraddescent.jl create mode 100644 test/algorithms/klminsqrtnaturalgraddescent.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index f62dbb200..6ab91eb1a 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -356,7 +356,8 @@ export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI include("algorithms/gauss_expected_grad_hess.jl") include("algorithms/klminwassfwdbwd.jl") +include("algorithms/klminsqrtnaturalgraddescent.jl") -export KLMinWassFwdBwd +export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent end diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl new file mode 100644 index 000000000..5693b2961 --- /dev/null +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -0,0 +1,157 @@ + +""" + KLMinSqrtNaturalGradDescent(n_samples, stepsize, subsampling) + KLMinSqrtNaturalGradDescent(; n_samples, stepsize, subsampling) + +KL divergence minimization algorithm obtained by discretizing the natural gradient flow under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023]. + +Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. + +# (Keyword) Arguments +- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) +- `stepsize::Float64`: Step size of stochastic proximal gradient descent. +- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. + +!!! note + The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinSqrtVarNewton` does not support amortization or structured variational families. + +# Output +- `q`: The last iterate of the algorithm. + +# Callback Signature +The `callback` function supplied to `optimize` needs to have the following signature: + + callback(; rng, iteration, q, info) + +The keyword arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `q`: Current variational approximation. +- `info`: `NamedTuple` containing the information generated during the current iteration. 
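+
+# Example
+A minimal usage sketch, assuming `AdvancedVI` and `LinearAlgebra` are loaded, `prob` is a target `LogDensityProblem` with at least first-order differentiation capability, and `d` is its dimension (`prob` and `d` are placeholder names for this example):
+
+    alg = KLMinSqrtNaturalGradDescent(; stepsize=1e-3, n_samples=10)
+    q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
+    q, info, _ = optimize(alg, 1_000, prob, q0)
+
+The step size typically needs tuning; the tests in this repository use values around `1e-3` to `1e-2`.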
+ +# Requirements +- The variational family is [`FullRankGaussian`](@ref FullRankGaussian). +- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). +- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. +""" +@kwdef struct KLMinSqrtNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: + AbstractVariationalAlgorithm + n_samples::Int = 1 + stepsize::Float64 + subsampling::Sub = nothing +end + +struct KLMinSqrtNaturalGradDescentState{Q,P,S,GradBuf,HessBuf} + q::Q + prob::P + iteration::Int + sub_st::S + grad_buf::GradBuf + hess_buf::HessBuf +end + +function init( + rng::Random.AbstractRNG, + alg::KLMinSqrtNaturalGradDescent, + q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, + prob, +) where {L} + sub = alg.subsampling + n_dims = LogDensityProblems.dimension(prob) + capability = LogDensityProblems.capabilities(typeof(prob)) + if capability < LogDensityProblems.LogDensityOrder{1}() + throw( + ArgumentError( + "`KLMinSqrtNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + ), + ) + end + sub_st = isnothing(sub) ? nothing : init(rng, sub) + grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) + hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) + return KLMinSqrtNaturalGradDescentState(q_init, prob, 0, sub_st, grad_buf, hess_buf) +end + +output(::KLMinSqrtNaturalGradDescent, state) = state.q + +function step( + rng::Random.AbstractRNG, alg::KLMinSqrtNaturalGradDescent, state, callback, objargs...; kwargs... +) + (; n_samples, stepsize, subsampling) = alg + (; q, prob, iteration, sub_st, grad_buf, hess_buf) = state + + m = q.location + C = q.scale + η = convert(eltype(m), stepsize) + iteration += 1 + + # Maybe apply subsampling + prob_sub, sub_st′, sub_inf = if isnothing(subsampling) + prob, sub_st, NamedTuple() + else + batch, sub_st′, sub_inf = step(rng, subsampling, sub_st) + prob_sub = subsample(prob, batch) + prob_sub, sub_st′, sub_inf + end + + # Estimate the Wasserstein gradient + logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( + rng, q, n_samples, grad_buf, hess_buf, prob_sub + ) + + CtHCmI = C'*-hess_buf*C - I + CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI))/2) + + m′ = m - η * C * (C' * -grad_buf) + C′ = C - η * C * CtHCmI_tril + + q′ = MvLocationScale(m′, C′, q.dist) + + state = KLMinSqrtNaturalGradDescentState(q′, prob, iteration, sub_st′, grad_buf, hess_buf) + elbo = logπ_avg + entropy(q′) + info = merge((elbo=elbo,), sub_inf) + + if !isnothing(callback) + info′ = callback(; rng, iteration, q, info) + info = !isnothing(info′) ? merge(info′, info) : info + end + state, false, info +end + +""" + estimate_objective([rng,] alg, q, prob; n_samples) + +Estimate the ELBO of the variational approximation `q` against the target log-density `prob`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `alg::KLMinSqrtNaturalGradDescent`: Variational inference algorithm. +- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the the number of samples used for estimating the gradient during optimization.) + +# Returns +- `obj_est`: Estimate of the objective value. 
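+
+# Example
+A minimal sketch, assuming `alg` is a `KLMinSqrtNaturalGradDescent`, `q` is a `FullRankGaussian` approximation (for instance, one returned by `optimize`), and `prob` is the target `LogDensityProblem`:
+
+    elbo_est = estimate_objective(alg, q, prob; n_samples=10^4)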
+""" +function estimate_objective( + rng::Random.AbstractRNG, + alg::KLMinSqrtNaturalGradDescent, + q::MvLocationScale{S,<:Normal,L}, + prob; + n_samples::Int=alg.n_samples, +) where {S,L} + obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) + if isnothing(alg.subsampling) + return estimate_objective(rng, obj, q, prob) + else + sub = alg.subsampling + sub_st = init(rng, sub) + return mapreduce(+, 1:length(sub)) do _ + batch, sub_st, _ = step(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + estimate_objective(rng, obj, q, prob_sub) / length(sub) + end + end +end diff --git a/test/algorithms/klminsqrtnaturalgraddescent.jl b/test/algorithms/klminsqrtnaturalgraddescent.jl new file mode 100644 index 000000000..fc70dc81c --- /dev/null +++ b/test/algorithms/klminsqrtnaturalgraddescent.jl @@ -0,0 +1,158 @@ + +@testset "KLMinSqrtNaturalGradDescent" begin + begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + (; model, n_dims, μ_true, L_true) = modelstats + + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "callback" begin + T = 10 + callback(; iteration, kwargs...) = (iteration_check=iteration,) + _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS) + @test [i.iteration_check for i in info] == 1:T + end + + @testset "estimate_objective" begin + q_true = FullRankGaussian(μ_true, LowerTriangular(Matrix(L_true))) + + obj_est = estimate_objective(alg, q_true, model) + @test isfinite(obj_est) + + obj_est = estimate_objective(alg, q_true, model; n_samples=10^5) + @test obj_est ≈ 0 atol=1e-2 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + T = 10 + + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end + + @testset "error low capability" begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0) + (; model, n_dims) = modelstats + + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1.0) + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + @test_throws "first-order" optimize(alg, 1, model, q0) + end + + @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [ + Float64, Float32 + ], + capability in [1, 2] + + modelstats = normal_meanfield(Random.default_rng(), realtype; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + T = 10 + + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(realtype, n_dims), L0) + + q, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + @test eltype(q.location) == eltype(μ_true) + @test eltype(q.scale) == eltype(L_true) + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = normal_meanfield(Random.default_rng(), Float64; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + + q_avg, _, _ = optimize(alg, T, model, q0; 
show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q_avg.location - μ_true) + sum(abs2, q_avg.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + + @testset "subsampling" begin + n_data = 8 + + @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + alg_sub = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + + obj_full = estimate_objective(alg, q0, model; n_samples=10^5) + obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5) + @test obj_full ≈ obj_sub rtol=0.1 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 10 + batchsize = 3 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + + q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = subsamplednormal(Random.default_rng(), n_data; capability) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 1000 + batchsize = 1 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-2, subsampling) + + q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ab67247b3..db66cb77e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,7 @@ if GROUP == "All" || GROUP == "GENERAL" include("families/location_scale_low_rank.jl") include("algorithms/klminwassfwdbwd.jl") + include("algorithms/klminsqrtnaturalgraddescent.jl") end if GROUP == "All" || GROUP == "AD" From 48daaa08f8fe27f6f82760c11e35e97d98cd2f65 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 02:48:10 -0500 Subject: [PATCH 03/32] apply formatter --- src/algorithms/klminsqrtnaturalgraddescent.jl | 11 +++++++++-- test/algorithms/klminsqrtnaturalgraddescent.jl | 12 +++++++++--- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 5693b2961..0a047e945 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -75,7 +75,12 @@ end output(::KLMinSqrtNaturalGradDescent, state) = state.q function step( - 
rng::Random.AbstractRNG, alg::KLMinSqrtNaturalGradDescent, state, callback, objargs...; kwargs... + rng::Random.AbstractRNG, + alg::KLMinSqrtNaturalGradDescent, + state, + callback, + objargs...; + kwargs..., ) (; n_samples, stepsize, subsampling) = alg (; q, prob, iteration, sub_st, grad_buf, hess_buf) = state @@ -107,7 +112,9 @@ function step( q′ = MvLocationScale(m′, C′, q.dist) - state = KLMinSqrtNaturalGradDescentState(q′, prob, iteration, sub_st′, grad_buf, hess_buf) + state = KLMinSqrtNaturalGradDescentState( + q′, prob, iteration, sub_st′, grad_buf, hess_buf + ) elbo = logπ_avg + entropy(q′) info = merge((elbo=elbo,), sub_inf) diff --git a/test/algorithms/klminsqrtnaturalgraddescent.jl b/test/algorithms/klminsqrtnaturalgraddescent.jl index fc70dc81c..7841c6d2d 100644 --- a/test/algorithms/klminsqrtnaturalgraddescent.jl +++ b/test/algorithms/klminsqrtnaturalgraddescent.jl @@ -101,7 +101,9 @@ subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) - alg_sub = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + alg_sub = KLMinSqrtNaturalGradDescent(; + n_samples=10, stepsize=1e-3, subsampling + ) obj_full = estimate_objective(alg, q0, model; n_samples=10^5) obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5) @@ -121,7 +123,9 @@ T = 10 batchsize = 3 subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) - alg_sub = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + alg_sub = KLMinSqrtNaturalGradDescent(; + n_samples=10, stepsize=1e-3, subsampling + ) q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) μ = q.location @@ -145,7 +149,9 @@ T = 1000 batchsize = 1 subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) - alg_sub = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-2, subsampling) + alg_sub = KLMinSqrtNaturalGradDescent(; + n_samples=10, stepsize=1e-2, subsampling + ) q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS) From 3483e8dbf9b110588b5e48869ccefb1a300933d3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 03:56:49 -0500 Subject: [PATCH 04/32] add natural gradient descent (variational online Newton) --- src/AdvancedVI.jl | 3 +- src/algorithms/klminnaturalgraddescent.jl | 175 +++++++++++++++++++++ test/algorithms/klminnaturalgraddescent.jl | 158 +++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 336 insertions(+), 1 deletion(-) create mode 100644 src/algorithms/klminnaturalgraddescent.jl create mode 100644 test/algorithms/klminnaturalgraddescent.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 6ab91eb1a..7d57e32a0 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -357,7 +357,8 @@ export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI include("algorithms/gauss_expected_grad_hess.jl") include("algorithms/klminwassfwdbwd.jl") include("algorithms/klminsqrtnaturalgraddescent.jl") +include("algorithms/klminnaturalgraddescent.jl") -export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent +export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent end diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl new file mode 100644 index 000000000..5a296fac5 --- /dev/null +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -0,0 +1,175 @@ + +""" + KLMinNaturalGradDescent(stepsize, ensure_posdef, n_samples, subsampling) + KLMinNaturalGradDescent(; stepsize, ensure_posdef, 
n_samples, subsampling) + +KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton. +This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence. + +If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness. +This, however, involves an additional set of matrix-matrix system solves that could be costly. + +Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. + +# (Keyword) Arguments +- `stepsize::Float64`: Step size of stochastic proximal gradient descent. +- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`) +- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) +- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. + +!!! note + The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families. + +# Output +- `q`: The last iterate of the algorithm. + +# Callback Signature +The `callback` function supplied to `optimize` needs to have the following signature: + + callback(; rng, iteration, q, info) + +The keyword arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `q`: Current variational approximation. +- `info`: `NamedTuple` containing the information generated during the current iteration. + +# Requirements +- The variational family is [`FullRankGaussian`](@ref FullRankGaussian). +- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). +- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. +""" +@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: + AbstractVariationalAlgorithm + stepsize::Float64 + ensure_posdef::Bool = true + n_samples::Int = 1 + subsampling::Sub = nothing +end + +struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf} + q::Q + prob::P + prec::Prec + iteration::Int + sub_st::S + grad_buf::GradBuf + hess_buf::HessBuf +end + +function init( + rng::Random.AbstractRNG, + alg::KLMinNaturalGradDescent, + q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, + prob, +) where {L} + sub = alg.subsampling + n_dims = LogDensityProblems.dimension(prob) + capability = LogDensityProblems.capabilities(typeof(prob)) + if capability < LogDensityProblems.LogDensityOrder{1}() + throw( + ArgumentError( + "`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + ), + ) + end + sub_st = isnothing(sub) ? 
nothing : init(rng, sub) + grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) + hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) + return KLMinNaturalGradDescentState( + q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf + ) +end + +output(::KLMinNaturalGradDescent, state) = state.q + +function step( + rng::Random.AbstractRNG, + alg::KLMinNaturalGradDescent, + state, + callback, + objargs...; + kwargs..., +) + (; ensure_posdef, n_samples, stepsize, subsampling) = alg + (; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state + + m = mean(q) + S = prec + η = convert(eltype(m), stepsize) + iteration += 1 + + # Maybe apply subsampling + prob_sub, sub_st′, sub_inf = if isnothing(subsampling) + prob, sub_st, NamedTuple() + else + batch, sub_st′, sub_inf = step(rng, subsampling, sub_st) + prob_sub = subsample(prob, batch) + prob_sub, sub_st′, sub_inf + end + + # Estimate the Wasserstein gradient + logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( + rng, q, n_samples, grad_buf, hess_buf, prob_sub + ) + + # Compute natural gradient descent update + S′ = Hermitian(((1 - η) * S + η * (-hess_buf))) + if ensure_posdef + G_hat = S - (-hess_buf) + S′ += η^2 / 2 * Hermitian(G_hat * (S′ \ G_hat)) + end + m′ = m - η * (S′ \ (-grad_buf)) + + q′ = MvLocationScale(m′, inv(cholesky(S′).L), q.dist) + + state = KLMinNaturalGradDescentState( + q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf + ) + elbo = logπ_avg + entropy(q′) + info = merge((elbo=elbo,), sub_inf) + + if !isnothing(callback) + info′ = callback(; rng, iteration, q, info) + info = !isnothing(info′) ? merge(info′, info) : info + end + state, false, info +end + +""" + estimate_objective([rng,] alg, q, prob; n_samples) + +Estimate the ELBO of the variational approximation `q` against the target log-density `prob`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `alg::KLMinNaturalGradDescent`: Variational inference algorithm. +- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the the number of samples used for estimating the gradient during optimization.) + +# Returns +- `obj_est`: Estimate of the objective value. 
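+
+# Example
+A minimal sketch of fitting a Gaussian approximation with this algorithm and then estimating its ELBO, assuming `AdvancedVI` and `LinearAlgebra` are loaded, `prob` is a target `LogDensityProblem` with at least first-order differentiation capability, and `d` is its dimension (`prob` and `d` are placeholder names):
+
+    alg = KLMinNaturalGradDescent(; stepsize=1e-3, n_samples=10)
+    q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
+    q, info, _ = optimize(alg, 1_000, prob, q0)
+    elbo_est = estimate_objective(alg, q, prob; n_samples=10^4)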
+""" +function estimate_objective( + rng::Random.AbstractRNG, + alg::KLMinNaturalGradDescent, + q::MvLocationScale{S,<:Normal,L}, + prob; + n_samples::Int=alg.n_samples, +) where {S,L} + obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) + if isnothing(alg.subsampling) + return estimate_objective(rng, obj, q, prob) + else + sub = alg.subsampling + sub_st = init(rng, sub) + return mapreduce(+, 1:length(sub)) do _ + batch, sub_st, _ = step(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + estimate_objective(rng, obj, q, prob_sub) / length(sub) + end + end +end diff --git a/test/algorithms/klminnaturalgraddescent.jl b/test/algorithms/klminnaturalgraddescent.jl new file mode 100644 index 000000000..e5cd5b558 --- /dev/null +++ b/test/algorithms/klminnaturalgraddescent.jl @@ -0,0 +1,158 @@ + +@testset "KLMinNaturalGradDescent" begin + begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + (; model, n_dims, μ_true, L_true) = modelstats + + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "callback" begin + T = 10 + callback(; iteration, kwargs...) = (iteration_check=iteration,) + _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS) + @test [i.iteration_check for i in info] == 1:T + end + + @testset "estimate_objective" begin + q_true = FullRankGaussian(μ_true, LowerTriangular(Matrix(L_true))) + + obj_est = estimate_objective(alg, q_true, model) + @test isfinite(obj_est) + + obj_est = estimate_objective(alg, q_true, model; n_samples=10^5) + @test obj_est ≈ 0 atol=1e-2 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + T = 10 + + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end + + @testset "error low capability" begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0) + (; model, n_dims) = modelstats + + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1.0) + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + @test_throws "first-order" optimize(alg, 1, model, q0) + end + + @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [ + Float64, Float32 + ], + capability in [1, 2] + + modelstats = normal_meanfield(Random.default_rng(), realtype; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + T = 10 + + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(realtype, n_dims), L0) + + q, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + @test eltype(q.location) == eltype(μ_true) + @test eltype(q.scale) == eltype(L_true) + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = normal_meanfield(Random.default_rng(), Float64; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + + q_avg, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, 
q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q_avg.location - μ_true) + sum(abs2, q_avg.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + + @testset "subsampling" begin + n_data = 8 + + @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + alg_sub = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + + obj_full = estimate_objective(alg, q0, model; n_samples=10^5) + obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5) + @test obj_full ≈ obj_sub rtol=0.1 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 10 + batchsize = 3 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + + q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = subsamplednormal(Random.default_rng(), n_data; capability) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 1000 + batchsize = 1 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-2, subsampling) + + q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index db66cb77e..5840eecbb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -67,6 +67,7 @@ if GROUP == "All" || GROUP == "GENERAL" include("algorithms/klminwassfwdbwd.jl") include("algorithms/klminsqrtnaturalgraddescent.jl") + include("algorithms/klminnaturalgraddescent.jl") end if GROUP == "All" || GROUP == "AD" From 8267a98d3430e64befa6b0065bab721f897c8bd9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 03:58:33 -0500 Subject: [PATCH 05/32] update docstrings remove redundant comments --- src/algorithms/klminnaturalgraddescent.jl | 6 ++---- src/algorithms/klminsqrtnaturalgraddescent.jl | 13 ++++++------- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 5a296fac5..d639579ad 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -12,9 +12,9 @@ This, however, involves an additional set of matrix-matrix system solves that co Denoting the target log-density as \$\$ \\log \\pi \$\$ and the 
current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. # (Keyword) Arguments -- `stepsize::Float64`: Step size of stochastic proximal gradient descent. +- `stepsize::Float64`: Step size. - `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`) -- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) +- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`) - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. !!! note @@ -108,12 +108,10 @@ function step( prob_sub, sub_st′, sub_inf end - # Estimate the Wasserstein gradient logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) - # Compute natural gradient descent update S′ = Hermitian(((1 - η) * S + η * (-hess_buf))) if ensure_posdef G_hat = S - (-hess_buf) diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 0a047e945..68a673e42 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -1,15 +1,15 @@ """ - KLMinSqrtNaturalGradDescent(n_samples, stepsize, subsampling) - KLMinSqrtNaturalGradDescent(; n_samples, stepsize, subsampling) + KLMinSqrtNaturalGradDescent(stepsize, n_samples, subsampling) + KLMinSqrtNaturalGradDescent(; stepsize, n_samples, subsampling) -KL divergence minimization algorithm obtained by discretizing the natural gradient flow under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023]. +KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemmanian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025]. Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. # (Keyword) Arguments -- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) -- `stepsize::Float64`: Step size of stochastic proximal gradient descent. +- `stepsize::Float64`: Step size. +- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`) - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. !!! 
note @@ -36,8 +36,8 @@ The keyword arguments are as follows: """ @kwdef struct KLMinSqrtNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: AbstractVariationalAlgorithm - n_samples::Int = 1 stepsize::Float64 + n_samples::Int = 1 subsampling::Sub = nothing end @@ -99,7 +99,6 @@ function step( prob_sub, sub_st′, sub_inf end - # Estimate the Wasserstein gradient logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) From f3790c3bf0fa29aed88186c4ccbb508d4f842fc0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 04:12:21 -0500 Subject: [PATCH 06/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/algorithms/klminwassfwdbwd.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 7f6c27f20..8d52bcdde 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -41,7 +41,6 @@ The keyword arguments are as follows: subsampling::Sub = nothing end - struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf} q::Q prob::P From 3ba84018e0f2daa4d4daac646fe725a07a93c3d7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 7 Nov 2025 04:38:12 -0500 Subject: [PATCH 07/32] update history --- HISTORY.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 1c4e19251..18ef01ab4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -4,6 +4,8 @@ This update adds new variational inference algorithms in light of the flexibilit Specifically, the following measure-space optimization algorithms have been added: - `KLMinWassFwdBwd` + - `KLMinNaturalGradDescent` + - `KLMinSqrtNaturalGradDescent` # Release 0.5 From bca8f55b1b98861e7e156bd7644b458a52284c14 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:21:15 -0500 Subject: [PATCH 08/32] fix gauss expected grad hess, use in-place operations, add tests --- src/algorithms/gauss_expected_grad_hess.jl | 37 +++++++++++--- test/general/gauss_expected_grad_hess.jl | 56 ++++++++++++++++++++++ test/runtests.jl | 2 + 3 files changed, 88 insertions(+), 7 deletions(-) create mode 100644 test/general/gauss_expected_grad_hess.jl diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl index bcd124c8b..f6a194bd4 100644 --- a/src/algorithms/gauss_expected_grad_hess.jl +++ b/src/algorithms/gauss_expected_grad_hess.jl @@ -4,6 +4,9 @@ Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity. +!!! warning + The resulting `hess_buf` may not be perfectly symmetric due to numerical issues. It is therefore useful to wrap it in a `Symmetric` before usage. + # Arguments - `rng::Random.AbstractRNG`: Random number generator. - `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. @@ -26,29 +29,49 @@ function gaussian_expectation_gradient_and_hessian!( if LogDensityProblems.capabilities(typeof(prob)) ≤ LogDensityProblems.LogDensityOrder{1}() - # Use Stein's identity + # First-order-only: use Stein/Price identity for the Hessian + # + # E_{z ~ N(m, CC')} ∇2 log π(z) + # = E_{z ~ N(m, CC')} (CC')^{-1} (z - m) ∇ log π(z)T + # = E_{u ~ N(0, I)} C \ (u ∇ log π(z)T) . 
+ # + # Algorithmically, draw u ~ N(0, I), z = C u + m, where C = q.scale. + # Accumulate A = E[ u ∇ log π(z)T ], then map back: H = C \ A. d = LogDensityProblems.dimension(prob) u = randn(rng, T, d, n_samples) - z = q.scale*u .+ q.location + m, C = q.location, q.scale + z = C*u .+ m for b in 1:n_samples zb, ub = view(z, :, b), view(u, :, b) logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) logπ_avg += logπ/n_samples - grad_buf += ∇logπ/n_samples - hess_buf += ub*(∇logπ/n_samples)' + + rdiv!(∇logπ, n_samples) + ∇logπ_div_nsamples = ∇logπ + + grad_buf[:] .+= ∇logπ_div_nsamples + hess_buf[:, :] .+= ub*∇logπ_div_nsamples' end + hess_buf[:, :] .= C \ hess_buf return logπ_avg, grad_buf, hess_buf else - # Use sample average of the Hessian. + # Second-order: use naive sample average z = rand(rng, q, n_samples) for b in 1:n_samples zb = view(z, :, b) logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian( prob, zb ) + + rdiv!(∇logπ, n_samples) + ∇logπ_div_nsamples = ∇logπ + + rdiv!(∇2logπ, n_samples) + ∇2logπ_div_nsamples = ∇2logπ + logπ_avg += logπ/n_samples - grad_buf += ∇logπ/n_samples - hess_buf += ∇2logπ/n_samples + grad_buf[:] .+= ∇logπ_div_nsamples + hess_buf[:, :] .+= ∇2logπ_div_nsamples end return logπ_avg, grad_buf, hess_buf end diff --git a/test/general/gauss_expected_grad_hess.jl b/test/general/gauss_expected_grad_hess.jl new file mode 100644 index 000000000..8cb8e9908 --- /dev/null +++ b/test/general/gauss_expected_grad_hess.jl @@ -0,0 +1,56 @@ + +using BenchmarkTools + +struct TestQuad{S,C} + Σ::S + cap::C +end + +function LogDensityProblems.logdensity(model::TestQuad, x) + Σ = model.Σ + return -x'*Σ*x/2 +end + +function LogDensityProblems.logdensity_and_gradient(model::TestQuad, x) + Σ = model.Σ + return (LogDensityProblems.logdensity(model, x), -Σ*x) +end + +function LogDensityProblems.logdensity_gradient_and_hessian(model::TestQuad, x) + Σ = model.Σ + ℓp, ∇ℓp = LogDensityProblems.logdensity_and_gradient(model, x) + return (ℓp, ∇ℓp, -Σ) +end + +function LogDensityProblems.dimension(model::TestQuad) + return size(model.Σ, 1) +end + +function LogDensityProblems.capabilities(::Type{TestQuad{S,C}}) where {S,C} + return C() +end + +@testset "gauss_expected_grad_hess" begin + n_samples = 10^6 + d = 2 + Σ = [2.0 -0.1; -0.1 2.0] + q = FullRankGaussian(ones(d), LowerTriangular(diagm(fill(0.1, d)))) + + # True expected gradient is E_{x ~ N(μ, 1)} -Σ x = -Σ μ + # True expected Hessian is E_{x ~ N(μ, 1)} -Σ = -Σ + E_∇ℓπ = -Σ*q.location + E_∇2ℓπ = -Σ + + @testset "$(cap)-order capability" for cap in [ + LogDensityProblems.LogDensityOrder{1}(), LogDensityProblems.LogDensityOrder{2}() + ] + grad_buf = zeros(d) + hess_buf = zeros(d, d) + prob = TestQuad(Σ, cap) + display(@benchmark AdvancedVI.gaussian_expectation_gradient_and_hessian!( + Random.default_rng(), $q, $n_samples, $grad_buf, $hess_buf, $prob + )) + @test grad_buf ≈ E_∇ℓπ atol=1e-1 + @test hess_buf ≈ E_∇2ℓπ atol=1e-1 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5840eecbb..105b5bec3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,6 +65,7 @@ if GROUP == "All" || GROUP == "GENERAL" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") + include("general/gauss_expected_grad_hess.jl") include("algorithms/klminwassfwdbwd.jl") include("algorithms/klminsqrtnaturalgraddescent.jl") include("algorithms/klminnaturalgraddescent.jl") @@ -85,3 +86,4 @@ if GROUP == "All" || GROUP == "AD" include("algorithms/scoregradelbo_locationscale.jl") 
include("algorithms/scoregradelbo_locationscale_bijectors.jl") end + From 82e9f156cc593890ba38a33c2f5fa859a27a3f48 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:22:33 -0500 Subject: [PATCH 09/32] fix always wrap `hess_buf` with a `Symmetric` (not `Hermitian`) --- src/algorithms/klminnaturalgraddescent.jl | 12 ++++++------ src/algorithms/klminsqrtnaturalgraddescent.jl | 2 +- src/algorithms/klminwassfwdbwd.jl | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index d639579ad..6a53459de 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -1,7 +1,7 @@ """ - KLMinNaturalGradDescent(stepsize, ensure_posdef, n_samples, subsampling) - KLMinNaturalGradDescent(; stepsize, ensure_posdef, n_samples, subsampling) + KLMinNaturalGradDescent(stepsize, n_samples, ensure_posdef, subsampling) + KLMinNaturalGradDescent(; stepsize, n_samples, ensure_posdef, subsampling) KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton. This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence. @@ -13,8 +13,8 @@ Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variatio # (Keyword) Arguments - `stepsize::Float64`: Step size. -- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`) - `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`) +- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`) - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. !!! 
note @@ -42,8 +42,8 @@ The keyword arguments are as follows: @kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: AbstractVariationalAlgorithm stepsize::Float64 - ensure_posdef::Bool = true n_samples::Int = 1 + ensure_posdef::Bool = true subsampling::Sub = nothing end @@ -112,9 +112,9 @@ function step( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) - S′ = Hermitian(((1 - η) * S + η * (-hess_buf))) + S′ = Hermitian(((1 - η) * S + η * Symmetric(-hess_buf))) if ensure_posdef - G_hat = S - (-hess_buf) + G_hat = S - Symmetric(-hess_buf) S′ += η^2 / 2 * Hermitian(G_hat * (S′ \ G_hat)) end m′ = m - η * (S′ \ (-grad_buf)) diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 68a673e42..fd53cd94d 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -103,7 +103,7 @@ function step( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) - CtHCmI = C'*-hess_buf*C - I + CtHCmI = C'*Symmetric(-hess_buf)*C - I CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI))/2) m′ = m - η * C * (C' * -grad_buf) diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 8d52bcdde..daefadd07 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -101,7 +101,7 @@ function step( ) m′ = m - η * (-grad_buf) - M = I - η*Hermitian(-hess_buf) + M = I - η*Symmetric(-hess_buf) Σ_half = Hermitian(M*Σ*M) # Compute the JKO proximal operator From 8fdecb16a10ba2d5c98123d2d3758d28089529b1 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:26:36 -0500 Subject: [PATCH 10/32] Apply suggestion from @sunxd3 Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- src/algorithms/klminsqrtnaturalgraddescent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index fd53cd94d..28c5c1c39 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -13,7 +13,7 @@ Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variatio - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. !!! note - The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinSqrtVarNewton` does not support amortization or structured variational families. + The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinSqrtNaturalGradDescent` does not support amortization or structured variational families. # Output - `q`: The last iterate of the algorithm. 
From 3ff2c0f42fdcb1f25677477e2c83ab4a2f818ce7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:26:48 -0500 Subject: [PATCH 11/32] Apply suggestion from @sunxd3 Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- src/algorithms/klminsqrtnaturalgraddescent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 28c5c1c39..22531bae8 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -3,7 +3,7 @@ KLMinSqrtNaturalGradDescent(stepsize, n_samples, subsampling) KLMinSqrtNaturalGradDescent(; stepsize, n_samples, subsampling) -KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemmanian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025]. +KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025]. Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. 
From 45b998958dd68939301b957b8e495f9302c4a582 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:27:32 -0500 Subject: [PATCH 12/32] Apply suggestion from @github-actions[bot] Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 105b5bec3..0d02d0168 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -86,4 +86,3 @@ if GROUP == "All" || GROUP == "AD" include("algorithms/scoregradelbo_locationscale.jl") include("algorithms/scoregradelbo_locationscale_bijectors.jl") end - From 75e489dff789fed6ab60e34d251220e8199578c8 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:27:58 -0500 Subject: [PATCH 13/32] fix bug in init of klminnaturalgraddescent --- src/algorithms/klminnaturalgraddescent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 6a53459de..4756d1a9a 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -77,7 +77,7 @@ function init( grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) return KLMinNaturalGradDescentState( - q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf + q_init, prob, inv(cov(q_init)), 0, sub_st, grad_buf, hess_buf ) end From 6f55a5c281b7b0d193795261db2f64969739dc3e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:28:08 -0500 Subject: [PATCH 14/32] remove unintended benchmark code --- test/general/gauss_expected_grad_hess.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/general/gauss_expected_grad_hess.jl b/test/general/gauss_expected_grad_hess.jl index 8cb8e9908..ee254a6a9 100644 --- a/test/general/gauss_expected_grad_hess.jl +++ b/test/general/gauss_expected_grad_hess.jl @@ -1,6 +1,4 @@ -using BenchmarkTools - struct TestQuad{S,C} Σ::S cap::C @@ -47,9 +45,9 @@ end grad_buf = zeros(d) hess_buf = zeros(d, d) prob = TestQuad(Σ, cap) - display(@benchmark AdvancedVI.gaussian_expectation_gradient_and_hessian!( - Random.default_rng(), $q, $n_samples, $grad_buf, $hess_buf, $prob - )) + AdvancedVI.gaussian_expectation_gradient_and_hessian!( + Random.default_rng(), q, n_samples, grad_buf, hess_buf, prob + ) @test grad_buf ≈ E_∇ℓπ atol=1e-1 @test hess_buf ≈ E_∇2ℓπ atol=1e-1 end From 78d3559edf48bae7eb2ff4e90cd7a48850c3c56b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:29:53 -0500 Subject: [PATCH 15/32] update docs --- src/algorithms/gauss_expected_grad_hess.jl | 4 +++- src/algorithms/klminnaturalgraddescent.jl | 4 +++- src/algorithms/klminsqrtnaturalgraddescent.jl | 4 +++- src/algorithms/klminwassfwdbwd.jl | 4 +++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl index f6a194bd4..457439d14 100644 --- a/src/algorithms/gauss_expected_grad_hess.jl +++ b/src/algorithms/gauss_expected_grad_hess.jl @@ -2,7 +2,9 @@ """ gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) -Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. 
Otherwise, it uses Stein's identity. +Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. +For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. +Otherwise, it uses Stein's identity. !!! warning The resulting `hess_buf` may not be perfectly symmetric due to numerical issues. It is therefore useful to wrap it in a `Symmetric` before usage. diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 4756d1a9a..5412a5327 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -9,7 +9,9 @@ This algorithm can be viewed as an instantiation of mirror descent, where the Br If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness. This, however, involves an additional set of matrix-matrix system solves that could be costly. -Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. +The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. +If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. +If the target has only first-order capability, we use Stein's identity. # (Keyword) Arguments - `stepsize::Float64`: Step size. diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 22531bae8..052e4124c 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -5,7 +5,9 @@ KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025]. -Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. +The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. 
+If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. +If the target has only first-order capability, we use Stein's identity. # (Keyword) Arguments - `stepsize::Float64`: Step size. diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index daefadd07..14bd52dbe 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -5,7 +5,9 @@ KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023]. -Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. +The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. +If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. +If the target has only first-order capability, we use Stein's identity. # (Keyword) Arguments - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) From 020634db43b61670c68db1beb177b6cef39b9727 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 11 Nov 2025 11:36:35 -0500 Subject: [PATCH 16/32] fix relax Hermitian to Symmetric in NGVI ensure posdef --- src/algorithms/klminnaturalgraddescent.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 5412a5327..88dafc5fc 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -117,7 +117,7 @@ function step( S′ = Hermitian(((1 - η) * S + η * Symmetric(-hess_buf))) if ensure_posdef G_hat = S - Symmetric(-hess_buf) - S′ += η^2 / 2 * Hermitian(G_hat * (S′ \ G_hat)) + S′ += η^2 / 2 * Symmetric(G_hat * (S′ \ G_hat)) end m′ = m - η * (S′ \ (-grad_buf)) From 321fb92abf3555e28b876c859efdf194a7e11e14 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 09:06:04 -0500 Subject: [PATCH 17/32] fix gauss expected grad hess --- src/algorithms/gauss_expected_grad_hess.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl index 457439d14..466292b04 100644 --- a/src/algorithms/gauss_expected_grad_hess.jl +++ b/src/algorithms/gauss_expected_grad_hess.jl @@ -35,7 +35,7 @@ function gaussian_expectation_gradient_and_hessian!( # # E_{z ~ N(m, CC')} ∇2 log π(z) # = E_{z ~ N(m, CC')} (CC')^{-1} (z - m) ∇ log π(z)T - # = E_{u ~ N(0, I)} C \ (u ∇ log π(z)T) . + # = E_{u ~ N(0, I)} C' \ (u ∇ log π(z)T) . # # Algorithmically, draw u ~ N(0, I), z = C u + m, where C = q.scale. 
# Accumulate A = E[ u ∇ log π(z)T ], then map back: H = C \ A. @@ -54,7 +54,7 @@ function gaussian_expectation_gradient_and_hessian!( grad_buf[:] .+= ∇logπ_div_nsamples hess_buf[:, :] .+= ub*∇logπ_div_nsamples' end - hess_buf[:, :] .= C \ hess_buf + hess_buf[:, :] .= C' \ hess_buf return logπ_avg, grad_buf, hess_buf else # Second-order: use naive sample average From f7f965a6892a4688d679e3999d5bf107801e0a7b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 09:07:56 -0500 Subject: [PATCH 18/32] fix callback argument in measure space algorithms --- src/algorithms/klminnaturalgraddescent.jl | 2 +- src/algorithms/klminsqrtnaturalgraddescent.jl | 2 +- src/algorithms/klminwassfwdbwd.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 88dafc5fc..972de08d9 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -130,7 +130,7 @@ function step( info = merge((elbo=elbo,), sub_inf) if !isnothing(callback) - info′ = callback(; rng, iteration, q, info) + info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end state, false, info diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl index 052e4124c..a26af011c 100644 --- a/src/algorithms/klminsqrtnaturalgraddescent.jl +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -120,7 +120,7 @@ function step( info = merge((elbo=elbo,), sub_inf) if !isnothing(callback) - info′ = callback(; rng, iteration, q, info) + info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end state, false, info diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 14bd52dbe..e04184fb4 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -115,7 +115,7 @@ function step( info = merge((elbo=elbo,), sub_inf) if !isnothing(callback) - info′ = callback(; rng, iteration, q, info) + info′ = callback(; rng, iteration, q=q′, info) info = !isnothing(info′) ? merge(info′, info) : info end state, false, info From 49236af5835c2cd3b300d93b6daff0cbac46e0eb Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 10:28:19 -0500 Subject: [PATCH 19/32] fix the positive definite preserving update rule in NGVI --- src/algorithms/gauss_expected_grad_hess.jl | 2 +- src/algorithms/klminnaturalgraddescent.jl | 31 ++++++++++++++++------ 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl index 466292b04..af5c21b22 100644 --- a/src/algorithms/gauss_expected_grad_hess.jl +++ b/src/algorithms/gauss_expected_grad_hess.jl @@ -19,7 +19,7 @@ Otherwise, it uses Stein's identity. 
""" function gaussian_expectation_gradient_and_hessian!( rng::Random.AbstractRNG, - q::MvLocationScale{<:LowerTriangular,<:Normal,L}, + q::MvLocationScale{<:LinearAlgebra.AbstractTriangular,<:Normal,L}, n_samples::Int, grad_buf::AbstractVector{T}, hess_buf::AbstractMatrix{T}, diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl index 972de08d9..d465b6037 100644 --- a/src/algorithms/klminnaturalgraddescent.jl +++ b/src/algorithms/klminnaturalgraddescent.jl @@ -49,10 +49,11 @@ The keyword arguments are as follows: subsampling::Sub = nothing end -struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf} +struct KLMinNaturalGradDescentState{Q,P,S,Prec,QCov,GradBuf,HessBuf} q::Q prob::P prec::Prec + qcov::QCov iteration::Int sub_st::S grad_buf::GradBuf @@ -78,8 +79,13 @@ function init( sub_st = isnothing(sub) ? nothing : init(rng, sub) grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) + scale = q_init.scale + qcov = Hermitian(scale*scale') + scale_inv = inv(scale) + prec_chol = scale_inv' + prec = Hermitian(prec_chol*prec_chol') return KLMinNaturalGradDescentState( - q_init, prob, inv(cov(q_init)), 0, sub_st, grad_buf, hess_buf + q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf ) end @@ -94,7 +100,7 @@ function step( kwargs..., ) (; ensure_posdef, n_samples, stepsize, subsampling) = alg - (; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state + (; q, prob, prec, qcov, iteration, sub_st, grad_buf, hess_buf) = state m = mean(q) S = prec @@ -114,17 +120,26 @@ function step( rng, q, n_samples, grad_buf, hess_buf, prob_sub ) - S′ = Hermitian(((1 - η) * S + η * Symmetric(-hess_buf))) - if ensure_posdef + S′ = if ensure_posdef + # Udpate rule guaranteeing positive definiteness in the proof of Theorem 1. + # Lin, W., Schmidt, M., & Khan, M. E. + # Handling the positive-definite constraint in the Bayesian learning rule. + # In ICML 2020. 
         G_hat = S - Symmetric(-hess_buf)
+        Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat)
+    else
+        Hermitian(((1 - η) * S + η * Symmetric(-hess_buf)))
     end
 
     m′ = m - η * (S′ \ (-grad_buf))
 
-    q′ = MvLocationScale(m′, inv(cholesky(S′).L), q.dist)
+    prec_chol = cholesky(S′).L
+    prec_chol_inv = inv(prec_chol)
+    scale = prec_chol_inv'
+    qcov = Hermitian(scale*scale')
+    q′ = MvLocationScale(m′, scale, q.dist)
 
     state = KLMinNaturalGradDescentState(
-        q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf
+        q′, prob, S′, qcov, iteration, sub_st′, grad_buf, hess_buf
     )
     elbo = logπ_avg + entropy(q′)
     info = merge((elbo=elbo,), sub_inf)

From 5dba4c5be2189cabf2dd64aec88a01f1df7b5d80 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Nov 2025 14:04:47 -0500
Subject: [PATCH 20/32] add the docs for the natural gradient descent algorithms

---
 docs/make.jl                            |  2 +
 docs/src/klminnaturalgraddescent.md     | 75 +++++++++++++++++++++
 docs/src/klminsqrtnaturalgraddescent.md | 87 +++++++++++++++++++++++++
 3 files changed, 164 insertions(+)
 create mode 100644 docs/src/klminnaturalgraddescent.md
 create mode 100644 docs/src/klminsqrtnaturalgraddescent.md

diff --git a/docs/make.jl b/docs/make.jl
index 2cacf5fcb..7cfafcea0 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -27,6 +27,8 @@ makedocs(;
         "`KLMinRepGradProxDescent`" => "klminrepgradproxdescent.md",
         "`KLMinScoreGradDescent`" => "klminscoregraddescent.md",
         "`KLMinWassFwdBwd`" => "klminwassfwdbwd.md",
+        "`KLMinNaturalGradDescent`" => "klminnaturalgraddescent.md",
+        "`KLMinSqrtNaturalGradDescent`" => "klminsqrtnaturalgraddescent.md",
     ],
     "Variational Families" => "families.md",
     "Optimization" => "optimization.md",
diff --git a/docs/src/klminnaturalgraddescent.md b/docs/src/klminnaturalgraddescent.md
new file mode 100644
index 000000000..482fefaa5
--- /dev/null
+++ b/docs/src/klminnaturalgraddescent.md
@@ -0,0 +1,75 @@
+# [`KLMinNaturalGradDescent`](@id klminnaturalgraddescent)
+
+## Description
+
+This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent.
+`KLMinNaturalGradDescent` is a specific implementation of natural gradient variational inference (NGVI), also known as variational online Newton[^KR2023].
+For nearly-Gaussian targets, NGVI tends to converge very quickly.
+If the `ensure_posdef` option is set to `true` (this is the default configuration), then the update rule of Lin *et al.*[^LSK2020] is used, which guarantees that the updated precision matrix is always positive definite.
+Since `KLMinNaturalGradDescent` is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (`FullRankGaussian`), which makes the updates tractable.
+
+```@docs
+KLMinNaturalGradDescent
+```
+
+The associated objective value, which is the ELBO, can be estimated through the following:
+
+```@docs; canonical=false
+estimate_objective(
+    ::Random.AbstractRNG,
+    ::KLMinWassFwdBwd,
+    ::MvLocationScale,
+    ::Any;
+    ::Int,
+)
+```
+
+[^KR2023]: Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. *Journal of Machine Learning Research*, 24(281), 1-46.
+[^LSK2020]: Lin, W., Schmidt, M., & Khan, M. E. (2020). Handling the positive-definite constraint in the Bayesian learning rule. In *International Conference on Machine Learning*. PMLR.
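Editor's note on the page added above: the following is a minimal usage sketch, not part of the patch series. The toy target, the `FullRankGaussian` constructor call, the `stepsize` value, and the `n_samples` keyword are assumptions layered on top of the interfaces referenced in the Description, and the actual API of the package may differ.

```julia
# Hypothetical usage sketch; not part of the patch series.
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# A toy standard-normal target implementing the LogDensityProblems interface.
struct ToyGaussian
    dim::Int
end
LogDensityProblems.logdensity(::ToyGaussian, z) = -sum(abs2, z) / 2
LogDensityProblems.dimension(prob::ToyGaussian) = prob.dim
function LogDensityProblems.capabilities(::Type{ToyGaussian})
    return LogDensityProblems.LogDensityOrder{0}()
end

d = 5
prob = ToyGaussian(d)
q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))  # assumed constructor
alg = KLMinNaturalGradDescent(; stepsize=1e-2)  # `stepsize` documented above; value is illustrative

# Monte Carlo ELBO estimate at the initialization; the keyword name is an assumption.
elbo0 = estimate_objective(Random.default_rng(), alg, q0, prob; n_samples=32)
```

Note that actually running the algorithm, rather than only estimating the ELBO, would additionally require the target to provide at least first-order differentiation capability, as stated in the docstring requirements.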
+
+## [Methodology](@id klminnaturalgraddescent_method)
+
+This algorithm aims to solve the problem
+
+```math
+  \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
+```
+where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running stochastic gradient descent in the (Euclidean) space of parameters.
+That is, for all $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.
+
+Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
+Instead, the ELBO maximization strategy minimizes a surrogate objective, the *negative evidence lower bound*[^JGJS1999]
+
+```math
+  \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right),
+```
+which is equivalent to the KL up to an additive constant (the evidence).
+
+Suppose we had access to the exact gradients $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$.
+NGVI attempts to minimize $\mathcal{L}$ via natural gradient descent, which corresponds to iterating the mirror descent update
+
+```math
+\lambda_{t+1} = \argmin_{\lambda \in \Lambda} {\langle \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right), \lambda - \lambda_t \rangle} + \frac{1}{2 \gamma_t} \mathrm{KL}\left(q, q_{\lambda_t}\right) .
+```
+
+This turns out to be equivalent to the update
+
+```math
+\lambda_{t+1} = \lambda_{t} - \gamma_t {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t}) ,
+```
+where $F(\lambda_t)$ is the Fisher information matrix of $q_{\lambda}$.
+That is, natural gradient descent can be viewed as gradient descent with an iterate-dependent preconditioning.
+Furthermore, ${F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t})$ is referred to as the *natural gradient* of the KL divergence[^A1998], hence natural gradient variational inference.
+Also note that the gradient is taken with respect to the parameters of $q_{\lambda}$.
+Therefore, NGVI is parametrization-dependent: for the same variational family, different parametrizations will result in different behavior.
+However, the pseudo-metric $\mathrm{KL}\left(q, q_{\lambda_t}\right)$ is over measures.
+Therefore, NGVI tends to behave like a measure-space algorithm, although, technically speaking, it is not a fully measure-space algorithm.
+
+In practice, we do not have access to $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$ itself, only an unbiased estimate of it.
+Nevertheless, natural gradient descent/mirror descent updates based on such stochastic estimates have been derived for some variational families, for instance, Gaussian variational families[^KR2023] and mixtures of exponential families[^LKS2019].
+As of now, we only implement the Gaussian version.
+
+[^LKS2019]: Lin, W., Khan, M. E., & Schmidt, M. (2019). Fast and simple natural-gradient variational inference with mixture of exponential-family approximations. In *International Conference on Machine Learning*. PMLR.
+[^A1998]: Amari, S. I. (1998). Natural gradient works efficiently in learning. *Neural computation*, 10(2), 251-276.
+[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233.
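Editor's note: to make the update above concrete, here is a small sketch of one idealized NGVI step for a full-rank Gaussian in the precision parameterization, mirroring the `(1 - η) S + η (-E_q[∇² log π])` precision step and the preconditioned mean step used by `KLMinNaturalGradDescent` in this series. The function and variable names are hypothetical, and the expectations are assumed to be supplied by some estimator such as `gaussian_expectation_gradient_and_hessian!`.

```julia
# Editor-added sketch, not part of the patch series.
using LinearAlgebra

# One idealized NGVI step for q = N(m, inv(S)), given estimates of
# E_q[∇ log π] (`E_grad`) and E_q[∇² log π] (`E_hess`).
function ngvi_step(m::AbstractVector, S::AbstractMatrix, E_grad, E_hess, η::Real)
    # Precision update: convex combination of the current precision and -E_q[∇² log π].
    S_new = Hermitian((1 - η) * Matrix(S) + η * (-E_hess))
    # Mean update: a Newton-like step preconditioned by the updated precision.
    m_new = m - η * (S_new \ (-E_grad))
    return m_new, S_new
end

# Toy check on a standard Gaussian target, where ∇ log π(z) = -z and ∇² log π(z) = -I.
d = 3
m0, S0 = randn(d), Matrix(2.0I, d, d)
m1, S1 = ngvi_step(m0, S0, -m0, -Matrix(1.0I, d, d), 0.5)
```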
diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md
new file mode 100644
index 000000000..17e2e32f7
--- /dev/null
+++ b/docs/src/klminsqrtnaturalgraddescent.md
@@ -0,0 +1,87 @@
+
+# [`KLMinSqrtNaturalGradDescent`](@id klminsqrtnaturalgraddescent)
+
+## Description
+
+This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running natural gradient descent.
+`KLMinSqrtNaturalGradDescent` is a specific implementation of natural gradient variational inference (NGVI), also known as square-root variational Newton[^KMKL2025][^LDEBTM2024][^LDLNKS2023][^T2025].
+This algorithm operates under the square-root (or Cholesky) parameterization of the covariance matrix.
+This contrasts with [`KLMinNaturalGradDescent`](@ref klminnaturalgraddescent), which operates in the precision matrix parameterization, requiring a matrix inverse at each step.
+As a result, the per-iteration cost of `KLMinSqrtNaturalGradDescent` should be lower.
+Since `KLMinSqrtNaturalGradDescent` is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (`FullRankGaussian`), which makes the updates tractable.
+
+```@docs
+KLMinSqrtNaturalGradDescent
+```
+
+The associated objective value, which is the ELBO, can be estimated through the following:
+
+```@docs; canonical=false
+estimate_objective(
+    ::Random.AbstractRNG,
+    ::KLMinWassFwdBwd,
+    ::MvLocationScale,
+    ::Any;
+    ::Int,
+)
+```
+
+[^KMKL2025]: Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. *Transactions of Machine Learning Research*.
+[^LDEBTM2024]: Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In *International Conference on Machine Learning*.
+[^LDLNKS2023]: Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In *International Conference on Machine Learning*.
+[^T2025]: Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. *Journal of the Royal Statistical Society: Series B.*
+
+## [Methodology](@id klminsqrtnaturalgraddescent_method)
+
+This algorithm aims to solve the problem
+
+```math
+  \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
+```
+where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running stochastic gradient descent in the (Euclidean) space of parameters.
+That is, for all $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.
+
+Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
+Instead, the ELBO maximization strategy minimizes a surrogate objective, the *negative evidence lower bound*[^JGJS1999]
+
+```math
+  \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right),
+```
+which is equivalent to the KL up to an additive constant (the evidence).
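Editor's note: in this series the surrogate objective above is reported as the ELBO and estimated by Monte Carlo as the average log-density over draws from `q` plus the entropy of `q`. The sketch below illustrates that estimator; it assumes `q` is a multivariate distribution supporting `rand` and `entropy` and that `prob` implements `LogDensityProblems.logdensity`, and it is not part of the patch series.

```julia
# Editor-added sketch of a Monte Carlo estimate of the ELBO (that is, the negative of L(q)).
using Distributions, LogDensityProblems, Random, Statistics

function elbo_estimate(rng::Random.AbstractRNG, q, prob; n_samples::Int=32)
    zs = rand(rng, q, n_samples)  # d × n_samples matrix of draws from q
    logπ_avg = mean(
        LogDensityProblems.logdensity(prob, view(zs, :, i)) for i in 1:n_samples
    )
    return logπ_avg + entropy(q)  # E_q[log π] + H(q), up to Monte Carlo error
end
```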
+
+While `KLMinSqrtNaturalGradDescent` is close to a natural gradient variational inference algorithm, it can be derived in a variety of different ways.
+In fact, the update rule has been concurrently developed by several research groups[^KMKL2025][^LDEBTM2024][^LDLNKS2023][^T2025].
+Here, we will present the derivation by Kumar *et al.* [^KMKL2025].
+Consider the ideal natural gradient descent algorithm discussed [here](@ref klminnaturalgraddescent_method).
+This can be viewed as a discretization of the continuous-time dynamics given by the differential equation
+
+```math
+\dot{\lambda}_t
+=
+-{F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}\left(q_{\lambda_t}\right) .
+```
+
+This is also known as the *natural gradient flow*.
+Notice that the flow is over the parameters $\lambda_t$.
+Therefore, the natural gradient flow depends on the way we parametrize $q_{\lambda}$.
+For Gaussian variational families, if we specifically choose the *square-root* (or Cholesky) parametrization such that $q_{\lambda_t} = \mathrm{Normal}(m_t, C_t C_t^{\top})$, the flow of $\lambda_t = (m_t, C_t)$ is given as
+
+```math
+\begin{align*}
+\dot{m}_t &= C_t C_t^{\top} \mathbb{E}_{q_{\lambda_t}} \left[ \nabla \log \pi \right]
+\\
+\dot{C}_t &= C_t M\left( \mathrm{I}_d + C_t^{\top} \mathbb{E}\left[ \nabla^2 \log \pi \right] C_t \right) ,
+\end{align*}
+```
+where $M$ is a $\mathrm{tril}$-like function defined as
+```math
+{[ M(A) ]}_{ij} = \begin{cases}
+    0 & \text{if $i > j$} \\
+    \frac{1}{2} A_{ii} & \text{if $i = j$} \\
+    A_{ij} & \text{if $i < j$} .
+\end{cases}
+```
+
+`KLMinSqrtNaturalGradDescent` corresponds to the forward Euler discretization of this flow.
+
+[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233.

From 58d5cf5b3c63c808fed563f4488e496a83f896af Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Nov 2025 14:05:02 -0500
Subject: [PATCH 21/32] fix docstrings for the measure-space algorithms

---
 src/algorithms/klminnaturalgraddescent.jl     | 8 ++++----
 src/algorithms/klminsqrtnaturalgraddescent.jl | 8 ++++----
 src/algorithms/klminwassfwdbwd.jl             | 8 ++++----
 3 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl
index d465b6037..e47be81cf 100644
--- a/src/algorithms/klminnaturalgraddescent.jl
+++ b/src/algorithms/klminnaturalgraddescent.jl
@@ -9,7 +9,9 @@ This algorithm can be viewed as an instantiation of mirror descent, where the Br
 If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
 This, however, involves an additional set of matrix-matrix system solves that could be costly.
 
-The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
-If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
-If the target has only first-order capability, we use Stein's identity.
+This algorithm requires second-order information about the target.
+If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used.
+Otherwise, if the target has only first-order capability, it will use only gradients, but this will probably result in slower convergence and less robust behavior.
 
 # (Keyword) Arguments
 - `stepsize::Float64`: Step size.
@@ -38,7 +38,7 @@ The keyword arguments are as follows:
 
 # Requirements
 - The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
-- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
+- The target distribution has unconstrained support.
 - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
 """
 @kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl
index a26af011c..a8965baaf 100644
--- a/src/algorithms/klminsqrtnaturalgraddescent.jl
+++ b/src/algorithms/klminsqrtnaturalgraddescent.jl
@@ -5,9 +5,9 @@
 KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025].
 
-The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
+This algorithm requires second-order information about the target.
+If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used.
+Otherwise, if the target has only first-order capability, it will use only gradients, but this will probably result in slower convergence and less robust behavior.
 
 # (Keyword) Arguments
 - `stepsize::Float64`: Step size.
@@ -33,7 +33,7 @@ The keyword arguments are as follows:
 
 # Requirements
 - The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
-- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
+- The target distribution has unconstrained support.
 - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
 """
 @kwdef struct KLMinSqrtNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index e04184fb4..570321ac6 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -5,9 +5,9 @@
 KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023].
 
-The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
-If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
-If the target has only first-order capability, we use Stein's identity.
+This algorithm requires second-order information about the target.
+If the target `LogDensityProblem` has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), Hessians are used.
+Otherwise, if the target has only first-order capability, it will use only gradients, but this will probably result in slower convergence and less robust behavior.
 
 # (Keyword) Arguments
 - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`)
@@ -33,7 +33,7 @@ The keyword arguments are as follows:
 
 # Requirements
 - The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
-- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
+- The target distribution has unconstrained support.
 - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
 """
 @kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <:

From d6a29d23c5adee63212fa7afe588a4e49556e3a2 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Nov 2025 14:17:45 -0500
Subject: [PATCH 22/32] run formatter

Co-authored-by: github-actions[bot]
 <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/klminnaturalgraddescent.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/src/klminnaturalgraddescent.md b/docs/src/klminnaturalgraddescent.md
index 482fefaa5..e2a603616 100644
--- a/docs/src/klminnaturalgraddescent.md
+++ b/docs/src/klminnaturalgraddescent.md
@@ -26,7 +26,6 @@ estimate_objective(
 
 [^KR2023]: Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. *Journal of Machine Learning Research*, 24(281), 1-46.
 [^LSK2020]: Lin, W., Schmidt, M., & Khan, M. E. (2020). Handling the positive-definite constraint in the Bayesian learning rule. In *International Conference on Machine Learning*. PMLR.
-
 ## [Methodology](@id klminnaturalgraddescent_method)

From 45c375187a30a92c98fe4e81a26505152d1e7daf Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 14 Nov 2025 14:17:52 -0500
Subject: [PATCH 23/32] run formatter

Co-authored-by: github-actions[bot]
 <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/klminnaturalgraddescent.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/klminnaturalgraddescent.md b/docs/src/klminnaturalgraddescent.md
index e2a603616..43bbc68ba 100644
--- a/docs/src/klminnaturalgraddescent.md
+++ b/docs/src/klminnaturalgraddescent.md
@@ -33,6 +33,7 @@ This algorithm aims to solve the problem
 ```math
   \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right)
 ```
+
 where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running stochastic gradient descent in the (Euclidean) space of parameters.
 That is, for all $$q_{\lambda} \in \mathcal{Q}$$, we assume there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$.
From 3395851e3d93e358f04a8c4d3c9096bd9e193dee Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:17:58 -0500 Subject: [PATCH 24/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminnaturalgraddescent.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/klminnaturalgraddescent.md b/docs/src/klminnaturalgraddescent.md index 43bbc68ba..bbcb46116 100644 --- a/docs/src/klminnaturalgraddescent.md +++ b/docs/src/klminnaturalgraddescent.md @@ -43,6 +43,7 @@ Instead, the ELBO maximization strategy minimizes a surrogate objective, the *ne ```math \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right), ``` + which is equivalent to the KL up to an additive constant (the evidence). Suppose we had access to the exact gradients $\nabla_{\lambda} \mathcal{L}\left(q_{\lambda}\right)$. From 862fc609a92e80c9a9e2c1127e05229b5ec90be9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:04 -0500 Subject: [PATCH 25/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminnaturalgraddescent.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/klminnaturalgraddescent.md b/docs/src/klminnaturalgraddescent.md index bbcb46116..011e3b804 100644 --- a/docs/src/klminnaturalgraddescent.md +++ b/docs/src/klminnaturalgraddescent.md @@ -58,6 +58,7 @@ This turns out to be equivalent to the update ```math \lambda_{t+1} = \lambda_{t} - \gamma_t {F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t}) , ``` + where $F(\lambda_t)$ is the Fisher information matrix of $q_{\lambda}$. That is, natural gradient descent can be viewed as gradient descent with an iterate-dependent preconditioning. Furthermore, ${F(\lambda_t)}^{-1} \nabla_{\lambda} \mathcal{L}(q_{\lambda_t})$ is refered to as the *natural gradient* of the KL divergence[^A1998], hence natural gradient variational inference. From 36a4be7417a0754a7025fcf49f4b4c07e5da0753 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:10 -0500 Subject: [PATCH 26/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index 17e2e32f7..41aa69ebb 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -1,4 +1,3 @@ - # [`KLMinSqrtNaturalGradDescent`](@id klminsqrtnaturalgraddescent) ## Description From 76775c5c3ccd1752e8af1da2e21cdc56c30365d7 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:25 -0500 Subject: [PATCH 27/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index 41aa69ebb..6c9c38a5d 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -25,8 +25,8 @@ estimate_objective( ) ``` -[^KMKL2025]: Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. *Transactions of Machine Learning Research*. 
-[^LDEBTM2024]: Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In *International Conference on Machine Learning*. +[^KMKL2025]: Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. *Transactions of Machine Learning Research*. +[^LDEBTM2024]: Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In *International Conference on Machine Learning*. [^LDLNKS2023]: Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In *International Conference on Machine Learning*. [^T2025]: Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. *Journal of the Royal Statistical Society: Series B.* From d72f61cc621e43f4d2a47c2f2936b02fdfa48958 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:35 -0500 Subject: [PATCH 28/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index 6c9c38a5d..fa4feb8ac 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -29,7 +29,6 @@ estimate_objective( [^LDEBTM2024]: Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. In *International Conference on Machine Learning*. [^LDLNKS2023]: Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. In *International Conference on Machine Learning*. [^T2025]: Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. *Journal of the Royal Statistical Society: Series B.* - ## [Methodology](@id klminsqrtnaturalgraddescent_method) This algorithm aims to solve the problem From 87de3192f16f2ab571518ce9be0642cfa9303a43 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:43 -0500 Subject: [PATCH 29/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index fa4feb8ac..a76f3fbe5 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -36,6 +36,7 @@ This algorithm aims to solve the problem ```math \mathrm{minimize}_{q_{\lambda} \in \mathcal{Q}}\quad \mathrm{KL}\left(q_{\lambda}, \pi\right) ``` + where $\mathcal{Q}$ is some family of distributions, often called the variational family, by running stochastic gradient descent in the (Euclidean) space of parameters. 
That is, for all $$q_{\lambda} \in \mathcal{Q}$$, we assume $$q_{\lambda}$$ there is a corresponding vector of parameters $$\lambda \in \Lambda$$, where the space of parameters is Euclidean such that $$\Lambda \subset \mathbb{R}^p$$. From 1c6226aa45515735121109cfdb267719fd659d92 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:48 -0500 Subject: [PATCH 30/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index a76f3fbe5..b86596692 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -46,6 +46,7 @@ Instead, the ELBO maximization strategy minimizes a surrogate objective, the *ne ```math \mathcal{L}\left(q\right) \triangleq \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) - \mathbb{H}\left(q\right), ``` + which is equivalent to the KL up to an additive constant (the evidence). While `KLMinSqrtNaturalGradDescent` is close to a natural gradient variational inference algorithm, it can be derived in a variety of different ways. From 248112b3fa1b080a3d10de409dcf8b4b39fba311 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:18:53 -0500 Subject: [PATCH 31/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index b86596692..cb47f1af8 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -73,6 +73,7 @@ For Gaussian variational families, if we specifically choose the *square-root* ( \dot{C}_t &= C_t M\left( \mathrm{I}_d + C_t^{\top} \mathbb{E}\left[ \nabla^2 \log \pi \right] C_t \right) , \end{align*} ``` + where $M$ is a $\mathrm{tril}$-like function defined as ```math {[ M(A) ]}_{ij} = \begin{cases} From 51a0b3f6599fc5c3140337922ed55b7b97e5d61a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Nov 2025 14:19:19 -0500 Subject: [PATCH 32/32] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/klminsqrtnaturalgraddescent.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/klminsqrtnaturalgraddescent.md b/docs/src/klminsqrtnaturalgraddescent.md index cb47f1af8..1d5c45b7f 100644 --- a/docs/src/klminsqrtnaturalgraddescent.md +++ b/docs/src/klminsqrtnaturalgraddescent.md @@ -75,6 +75,7 @@ For Gaussian variational families, if we specifically choose the *square-root* ( ``` where $M$ is a $\mathrm{tril}$-like function defined as + ```math {[ M(A) ]}_{ij} = \begin{cases} 0 & \text{if $i > j$} \\
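Editor's note to close the series: the sketch below spells out the M map and a single forward-Euler step of the square-root natural gradient flow written out in `docs/src/klminsqrtnaturalgraddescent.md`. It is an illustration of the displayed equations only, not the package implementation; all names are hypothetical and the expectations are assumed to come from some estimator.

```julia
# Editor-added sketch, not part of the patch series.
using LinearAlgebra

# [M(A)]_ij = 0 if i > j, A_ii / 2 if i = j, and A_ij if i < j.
function M_map(A::AbstractMatrix)
    B = Matrix(triu(A))    # zero out the strict lower triangle
    B[diagind(B)] ./= 2    # halve the diagonal
    return B
end

# One forward-Euler step of the flow  dm/dt = C C' E_q[∇ log π],
# dC/dt = C M(I + C' E_q[∇² log π] C),  with step size η.
function sqrt_ngd_step(m::AbstractVector, C::AbstractMatrix, E_grad, E_hess, η::Real)
    m_new = m + η * (C * (C' * E_grad))
    C_new = C + η * (C * M_map(I + C' * E_hess * C))
    return m_new, C_new
end

# Toy check on a standard Gaussian target (∇ log π(z) = -z, ∇² log π(z) = -I).
d = 3
m0, C0 = randn(d), Matrix(2.0I, d, d)
m1, C1 = sqrt_ngd_step(m0, C0, -m0, -Matrix(1.0I, d, d), 0.1)
```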