Merged
22 commits
1980f93
move gaussian expectation of grad and hess to its own file
Red-Portal Nov 7, 2025
c84b453
add square-root variational newton algorithm
Red-Portal Nov 7, 2025
48daaa0
apply formatter
Red-Portal Nov 7, 2025
3483e8d
add natural gradient descent (variational online Newton)
Red-Portal Nov 7, 2025
8267a98
update docstrings remove redundant comments
Red-Portal Nov 7, 2025
f3790c3
run formatter
Red-Portal Nov 7, 2025
3ba8401
update history
Red-Portal Nov 7, 2025
bca8f55
fix gauss expected grad hess, use in-place operations, add tests
Red-Portal Nov 11, 2025
82e9f15
fix always wrap `hess_buf` with a `Symmetric` (not `Hermitian`)
Red-Portal Nov 11, 2025
8fdecb1
Apply suggestion from @sunxd3
Red-Portal Nov 11, 2025
3ff2c0f
Apply suggestion from @sunxd3
Red-Portal Nov 11, 2025
45b9989
Apply suggestion from @github-actions[bot]
Red-Portal Nov 11, 2025
75e489d
fix bug in init of klminnaturalgraddescent
Red-Portal Nov 11, 2025
6f55a5c
remove unintended benchmark code
Red-Portal Nov 11, 2025
78d3559
update docs
Red-Portal Nov 11, 2025
020634d
fix relax Hermitian to Symmetric in NGVI ensure posdef
Red-Portal Nov 11, 2025
ccf6506
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into natur…
Red-Portal Nov 11, 2025
9070c82
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into natur…
Red-Portal Nov 12, 2025
47c8ed2
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into natur…
Red-Portal Nov 14, 2025
321fb92
fix gauss expected grad hess
Red-Portal Nov 14, 2025
f7f965a
fix callback argument in measure space algorithms
Red-Portal Nov 14, 2025
49236af
fix the positive definite preserving update rule in NGVI
Red-Portal Nov 14, 2025
2 changes: 2 additions & 0 deletions HISTORY.md
@@ -4,6 +4,8 @@ This update adds new variational inference algorithms in light of the flexibilit
Specifically, the following measure-space optimization algorithms have been added:

- `KLMinWassFwdBwd`
- `KLMinNaturalGradDescent`
- `KLMinSqrtNaturalGradDescent`

In addition, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, and `KLMinScoreGradDescent` will now throw a `RuntimeException` if the objective value estimated at a given step is degenerate (`Inf` or `NaN`). Previously, the algorithms ran until `max_iter` even if the optimization run had already failed.

7 changes: 5 additions & 2 deletions src/AdvancedVI.jl
@@ -352,10 +352,13 @@ include("algorithms/common.jl")

export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

-# Other Algorithms
+# Natural and Wasserstein gradient descent algorithms

+include("algorithms/gauss_expected_grad_hess.jl")
 include("algorithms/klminwassfwdbwd.jl")
+include("algorithms/klminsqrtnaturalgraddescent.jl")
+include("algorithms/klminnaturalgraddescent.jl")

-export KLMinWassFwdBwd
+export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent

end
80 changes: 80 additions & 0 deletions src/algorithms/gauss_expected_grad_hess.jl
@@ -0,0 +1,80 @@

"""
gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob)

Estimate the expectation of the gradient and Hessian of the log-density of `prob` with respect to the Gaussian `q`.
If `prob` has second-order differentiation capability, the expectation of the Hessian is estimated by the sample average of the Hessian.
Otherwise, it is estimated via Stein's identity.

!!! warning
The resulting `hess_buf` may not be exactly symmetric due to numerical error. It is therefore advisable to wrap it in a `Symmetric` before use.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q::MvLocationScale{<:LinearAlgebra.AbstractTriangular,<:Normal,L}`: Gaussian distribution over which the expectation is taken.
- `n_samples::Int`: Number of samples used for estimation.
- `grad_buf::AbstractVector`: Buffer for the gradient estimate.
- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate.
- `prob`: The target `LogDensityProblem` whose log-density gradient and Hessian are averaged.
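
# Returns
- `logπ_avg`: Sample average of the log-density of `prob` over the draws from `q`.
- `grad_buf`: Estimate of the expected gradient (the mutated input buffer).
- `hess_buf`: Estimate of the expected Hessian (the mutated input buffer).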
"""
function gaussian_expectation_gradient_and_hessian!(
rng::Random.AbstractRNG,
q::MvLocationScale{<:LinearAlgebra.AbstractTriangular,<:Normal,L},
n_samples::Int,
grad_buf::AbstractVector{T},
hess_buf::AbstractMatrix{T},
prob,
) where {T<:Real,L}
logπ_avg = zero(T)
fill!(grad_buf, zero(T))
fill!(hess_buf, zero(T))

if LogDensityProblems.capabilities(typeof(prob)) ≤
LogDensityProblems.LogDensityOrder{1}()
# First-order-only: use Stein/Price identity for the Hessian
#
# E_{z ~ N(m, CC')} ∇2 log π(z)
# = E_{z ~ N(m, CC')} (CC')^{-1} (z - m) ∇ log π(z)T
# = E_{u ~ N(0, I)} C' \ (u ∇ log π(z)T) .
#
# Algorithmically, draw u ~ N(0, I), z = C u + m, where C = q.scale.
# Accumulate A = E[ u ∇ log π(z)T ], then map back: H = C' \ A.
d = LogDensityProblems.dimension(prob)
u = randn(rng, T, d, n_samples)
m, C = q.location, q.scale
z = C*u .+ m
for b in 1:n_samples
zb, ub = view(z, :, b), view(u, :, b)
logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
logπ_avg += logπ/n_samples

rdiv!(∇logπ, n_samples)
∇logπ_div_nsamples = ∇logπ

grad_buf[:] .+= ∇logπ_div_nsamples
hess_buf[:, :] .+= ub*∇logπ_div_nsamples'
end
hess_buf[:, :] .= C' \ hess_buf
return logπ_avg, grad_buf, hess_buf
else
# Second-order: use naive sample average
z = rand(rng, q, n_samples)
for b in 1:n_samples
zb = view(z, :, b)
logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
prob, zb
)

rdiv!(∇logπ, n_samples)
∇logπ_div_nsamples = ∇logπ

rdiv!(∇2logπ, n_samples)
∇2logπ_div_nsamples = ∇2logπ

logπ_avg += logπ/n_samples
grad_buf[:] .+= ∇logπ_div_nsamples
hess_buf[:, :] .+= ∇2logπ_div_nsamples
end
return logπ_avg, grad_buf, hess_buf
end
end
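As a sanity check on the Stein/Price identity used above, the following self-contained snippet (an editor's illustration, not part of the diff) compares the estimator `C' \ E[u ∇log π(z)ᵀ]` against the exact Hessian of a Gaussian target, which is the constant matrix `-inv(Σ_target)`. The names `Σ_target`, `m_q`, `C_q`, and `Acc` are illustrative and do not appear in AdvancedVI.

using LinearAlgebra, Random, Statistics

rng = Xoshiro(1)
d, n_samples = 3, 100_000

# Toy target π = N(0, Σ_target); its log-density Hessian is the constant -inv(Σ_target).
A_rand   = randn(rng, d, d)
Σ_target = Symmetric(A_rand*A_rand' + d*I)
∇logπ(z) = -(Σ_target \ z)

# Gaussian q = N(m_q, C_q*C_q') with a lower-triangular scale C_q, as in MvLocationScale.
m_q = randn(rng, d)
C_q = LowerTriangular(randn(rng, d, d) + 2d*I)

# Stein/Price estimator: H ≈ C' \ mean(u ∇log π(z)ᵀ) with z = C u + m.
u   = randn(rng, d, n_samples)
z   = C_q*u .+ m_q
Acc = zeros(d, d)
for b in 1:n_samples
    Acc .+= u[:, b] * ∇logπ(view(z, :, b))' ./ n_samples
end
H_est   = C_q' \ Acc
H_exact = -inv(Σ_target)

# The two should agree up to Monte Carlo error.
@assert isapprox(H_est, H_exact; rtol=0.1)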
190 changes: 190 additions & 0 deletions src/algorithms/klminnaturalgraddescent.jl
@@ -0,0 +1,190 @@

"""
KLMinNaturalGradDescent(stepsize, n_samples, ensure_posdef, subsampling)
KLMinNaturalGradDescent(; stepsize, n_samples, ensure_posdef, subsampling)

KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton.
This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence.

If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
This, however, involves additional matrix-matrix products, which can be costly.

The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
If the target has only first-order capability, we use Stein's identity.

# (Keyword) Arguments
- `stepsize::Float64`: Step size.
- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`)
- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)

!!! note
The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target `prob` implements the `LogDensityProblems` interface with at least first-order differentiation capability.
"""
@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
AbstractVariationalAlgorithm
stepsize::Float64
n_samples::Int = 1
ensure_posdef::Bool = true
subsampling::Sub = nothing
end

struct KLMinNaturalGradDescentState{Q,P,S,Prec,QCov,GradBuf,HessBuf}
q::Q
prob::P
prec::Prec
qcov::QCov
iteration::Int
sub_st::S
grad_buf::GradBuf
hess_buf::HessBuf
end

function init(
rng::Random.AbstractRNG,
alg::KLMinNaturalGradDescent,
q_init::MvLocationScale{<:LowerTriangular,<:Normal,L},
prob,
) where {L}
sub = alg.subsampling
n_dims = LogDensityProblems.dimension(prob)
capability = LogDensityProblems.capabilities(typeof(prob))
if capability < LogDensityProblems.LogDensityOrder{1}()
throw(
ArgumentError(
"`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
),
)
end
sub_st = isnothing(sub) ? nothing : init(rng, sub)
grad_buf = Vector{eltype(q_init.location)}(undef, n_dims)
hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims)
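# Initialize the precision S = (C C')⁻¹ and the covariance Σ = C C' from the scale C of q_init.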
scale = q_init.scale
qcov = Hermitian(scale*scale')
scale_inv = inv(scale)
prec_chol = scale_inv'
prec = Hermitian(prec_chol*prec_chol')
return KLMinNaturalGradDescentState(
q_init, prob, prec, qcov, 0, sub_st, grad_buf, hess_buf
)
end

output(::KLMinNaturalGradDescent, state) = state.q

function step(
rng::Random.AbstractRNG,
alg::KLMinNaturalGradDescent,
state,
callback,
objargs...;
kwargs...,
)
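# One step of natural gradient descent (variational online Newton), in terms of the
# precision S, covariance Σ = S⁻¹, mean m, stepsize η, and the Monte Carlo estimates
# g ≈ E_q[∇ log π] and H ≈ E_q[∇² log π] computed below:
#
#   S′ = (1 - η) S + η (-H)                            (ensure_posdef = false)
#   S′ = S - η Ĝ + (η²/2) Ĝ Σ Ĝ   with   Ĝ = S + H     (ensure_posdef = true)
#   m′ = m - η S′⁻¹ (-g)
#
# The new scale of q is then recovered from the Cholesky factor of S′.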
(; ensure_posdef, n_samples, stepsize, subsampling) = alg
(; q, prob, prec, qcov, iteration, sub_st, grad_buf, hess_buf) = state

m = mean(q)
S = prec
η = convert(eltype(m), stepsize)
iteration += 1

# Maybe apply subsampling
prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
prob, sub_st, NamedTuple()
else
batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
prob_sub = subsample(prob, batch)
prob_sub, sub_st′, sub_inf
end

logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!(
rng, q, n_samples, grad_buf, hess_buf, prob_sub
)

S′ = if ensure_posdef
# Update rule guaranteeing positive definiteness, from the proof of Theorem 1 in:
# Lin, W., Schmidt, M., & Khan, M. E.
# Handling the positive-definite constraint in the Bayesian learning rule.
# In ICML 2020.
G_hat = S - Symmetric(-hess_buf)
Hermitian(S - η*G_hat + η^2/2*G_hat*qcov*G_hat)
else
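# Standard precision update without the positive-definiteness correction:
# S′ = (1 - η) S + η E_q[-∇² log π].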
Hermitian(((1 - η) * S + η * Symmetric(-hess_buf)))
end
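# Mean update: a Newton-like step with the updated precision, m′ = m + η S′⁻¹ E_q[∇ log π].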
m′ = m - η * (S′ \ (-grad_buf))

prec_chol = cholesky(S′).L
prec_chol_inv = inv(prec_chol)
scale = prec_chol_inv'
qcov = Hermitian(scale*scale')
q′ = MvLocationScale(m′, scale, q.dist)

state = KLMinNaturalGradDescentState(
q′, prob, S′, qcov, iteration, sub_st′, grad_buf, hess_buf
)
elbo = logπ_avg + entropy(q′)
info = merge((elbo=elbo,), sub_inf)

if !isnothing(callback)
info′ = callback(; rng, iteration, q=q′, info)
info = !isnothing(info′) ? merge(info′, info) : info
end
state, false, info
end

"""
estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the ELBO of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::KLMinNaturalGradDescent`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: the number of samples used for estimating the gradient during optimization, `alg.n_samples`.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
rng::Random.AbstractRNG,
alg::KLMinNaturalGradDescent,
q::MvLocationScale{S,<:Normal,L},
prob;
n_samples::Int=alg.n_samples,
) where {S,L}
obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy())
if isnothing(alg.subsampling)
return estimate_objective(rng, obj, q, prob)
else
sub = alg.subsampling
sub_st = init(rng, sub)
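# Average the per-batch ELBO estimates over a full epoch of subsampling batches.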
return mapreduce(+, 1:length(sub)) do _
batch, sub_st, _ = step(rng, sub, sub_st)
prob_sub = subsample(prob, batch)
estimate_objective(rng, obj, q, prob_sub) / length(sub)
end
end
end
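To exercise the new algorithm end to end, the following self-contained sketch (an editor's illustration, not part of the diff) drives `KLMinNaturalGradDescent` through `AdvancedVI.init`, `AdvancedVI.step`, and `AdvancedVI.output`, whose signatures are defined above, on a toy standard-normal target. The target type `ToyGaussian`, the driver helper `run_vi`, and the `FullRankGaussian(location, scale)` initialization are illustrative assumptions rather than part of this PR; in practice the exported `optimize` entry point would be used instead.

using AdvancedVI, LogDensityProblems, LinearAlgebra, Random

# Illustrative target: a d-dimensional standard normal exposing only first-order
# capability, so the Stein-identity path of the Hessian estimator is exercised.
struct ToyGaussian
    d::Int
end
LogDensityProblems.dimension(prob::ToyGaussian) = prob.d
function LogDensityProblems.capabilities(::Type{ToyGaussian})
    return LogDensityProblems.LogDensityOrder{1}()
end
LogDensityProblems.logdensity(::ToyGaussian, z) = -sum(abs2, z)/2
function LogDensityProblems.logdensity_and_gradient(::ToyGaussian, z)
    return -sum(abs2, z)/2, -z
end

# Hypothetical driver mirroring the init/step/output loop that `optimize` performs.
function run_vi(rng, alg, q_init, prob; max_iter=100)
    state = AdvancedVI.init(rng, alg, q_init, prob)
    for _ in 1:max_iter
        state, _, _ = AdvancedVI.step(rng, alg, state, nothing)
    end
    return AdvancedVI.output(alg, state)
end

rng  = Xoshiro(0)
d    = 5
prob = ToyGaussian(d)

# Assumed constructor: FullRankGaussian(location, lower-triangular scale).
q0  = FullRankGaussian(zeros(d), LowerTriangular(Matrix(2.0I, d, d)))
alg = KLMinNaturalGradDescent(; stepsize=0.1, n_samples=8, ensure_posdef=true)

q = run_vi(rng, alg, q0, prob)

# Since the target is itself Gaussian, q should match it closely after convergence.
@show AdvancedVI.estimate_objective(rng, alg, q, prob)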