Merged
Changes from all commits
Commits
30 commits
342eaf5
refactor split shared batch reshuffling into `src/reshuffling.jl`
Red-Portal Oct 25, 2025
ba28e2e
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into wasse…
Red-Portal Oct 25, 2025
f9ddebb
refactor move `SubsampledNormals` code under `test/models/`
Red-Portal Oct 25, 2025
563bb5d
fix wrong capability for subsamplednormals
Red-Portal Oct 25, 2025
6674b42
fix remove unused fields
Red-Portal Oct 25, 2025
bd0b889
fix wrong variable name, add missing rng argument
Red-Portal Oct 25, 2025
fec5055
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into wasse…
Red-Portal Oct 25, 2025
bb2fa3a
fix missing `dimension` and wrong variance in subsampled normals
Red-Portal Oct 25, 2025
7b6246f
add Wasserstein VI algorithm
Red-Portal Oct 25, 2025
a3201f8
add tests for Wasserstein VI
Red-Portal Oct 25, 2025
f478e15
add docs for Wasserstein VI
Red-Portal Oct 25, 2025
9583ef9
fix formatting
Red-Portal Oct 25, 2025
ba31866
fix formatting
Red-Portal Oct 25, 2025
455118f
fix formatting
Red-Portal Oct 25, 2025
8321555
fix typos
Red-Portal Oct 26, 2025
c985932
add optional Stein's identity expected Hessian estimator
Red-Portal Oct 29, 2025
17c1069
fix test also capability in subsampling convergence test
Red-Portal Oct 29, 2025
283961f
update docs remove comment that hessian is required for fwdbwdwass
Red-Portal Oct 29, 2025
724103c
update docs
Red-Portal Oct 29, 2025
f8ca2be
apply formatter
Red-Portal Oct 29, 2025
b478db9
apply formatter
Red-Portal Oct 29, 2025
176726d
apply formatter
Red-Portal Oct 29, 2025
53522c9
apply formatter
Red-Portal Oct 29, 2025
e21084c
update docs
Red-Portal Oct 30, 2025
fd36d38
update docs
Red-Portal Oct 30, 2025
5ef5805
run formatter
Red-Portal Oct 30, 2025
ff67c2e
refactor test for `KLMinWassFwdBwd`
Red-Portal Oct 30, 2025
2bd3bc8
run formatter
Red-Portal Oct 30, 2025
85dcaa9
improve docstrings callback section
Red-Portal Oct 30, 2025
4b34875
bump patch version
Red-Portal Nov 1, 2025
7 changes: 7 additions & 0 deletions HISTORY.md
@@ -1,3 +1,10 @@
# Release 0.5.1

This update adds a new variational inference algorithm, building on the flexibility introduced in the v0.5 update.
Specifically, the following measure-space optimization algorithm has been added:

- `KLMinWassFwdBwd`

# Release 0.5

## Default Configuration Changes
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.5.0"
version = "0.5.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
7 changes: 4 additions & 3 deletions docs/make.jl
@@ -23,9 +23,10 @@ makedocs(;
"Normalizing Flows" => "tutorials/flows.md",
],
"Algorithms" => [
"KLMinRepGradDescent" => "klminrepgraddescent.md",
"KLMinRepGradProxDescent" => "klminrepgradproxdescent.md",
"KLMinScoreGradDescent" => "klminscoregraddescent.md",
"`KLMinRepGradDescent`" => "klminrepgraddescent.md",
"`KLMinRepGradProxDescent`" => "klminrepgradproxdescent.md",
"`KLMinScoreGradDescent`" => "klminscoregraddescent.md",
"`KLMinWassFwdBwd`" => "klminwassfwdbwd.md",
],
"Variational Families" => "families.md",
"Optimization" => "optimization.md",
1 change: 1 addition & 0 deletions docs/src/index.md
@@ -17,3 +17,4 @@ For using the algorithms implemented in `AdvancedVI`, refer to the corresponding
- [KLMinRepGradDescent](@ref klminrepgraddescent) (alias of `ADVI`)
- [KLMinRepGradProxDescent](@ref klminrepgradproxdescent)
- [KLMinScoreGradDescent](@ref klminscoregraddescent) (alias of `BBVI`)
- [KLMinWassFwdBwd](@ref klminwassfwdbwd)
77 changes: 77 additions & 0 deletions docs/src/klminwassfwdbwd.md
@@ -0,0 +1,77 @@
# [`KLMinWassFwdBwd`](@id klminwassfwdbwd)

## Description

This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running proximal gradient descent (also known as forward-backward splitting) in Wasserstein space[^DBCS2023].
(This algorithm is also sometimes referred to as "Wasserstein VI".)
Since `KLMinWassFwdBwd` is a measure-space algorithm, its use is restricted to the full-rank Gaussian variational family (`FullRankGaussian`), which makes the measure-space operations tractable.

```@docs
KLMinWassFwdBwd
```

The associated objective value, which is the ELBO, can be estimated through the following:

```@docs; canonical=false
estimate_objective(
::Random.AbstractRNG,
::KLMinWassFwdBwd,
::MvLocationScale,
::Any;
::Int,
)
```
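
For illustration, the following is a minimal usage sketch.
Here, `IsoGaussian` is a toy target defined only for this example, the `FullRankGaussian(location, scale)` constructor is assumed for the variational family, and the optimization step itself is only indicated in a comment; refer to the `optimize` docstring for its exact signature.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# A toy isotropic Gaussian target with first-order differentiation capability,
# so `KLMinWassFwdBwd` will fall back to the Stein-identity Hessian estimator.
struct IsoGaussian
    d::Int
end
LogDensityProblems.logdensity(p::IsoGaussian, x) = -sum(abs2, x)/2
LogDensityProblems.dimension(p::IsoGaussian) = p.d
LogDensityProblems.capabilities(::Type{IsoGaussian}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity_and_gradient(p::IsoGaussian, x) = (-sum(abs2, x)/2, -x)

d    = 5
prob = IsoGaussian(d)
q0   = FullRankGaussian(zeros(d), LowerTriangular(Matrix(1.0I, d, d)))
alg  = KLMinWassFwdBwd(; n_samples=8, stepsize=1e-2)

# Running the optimization goes through `AdvancedVI.optimize`, roughly:
#     q, info, state = optimize(alg, n_iterations, prob, q0; callback)

# Monte Carlo estimate of the ELBO of a given variational approximation:
estimate_objective(Random.default_rng(), alg, q0, prob; n_samples=256)
```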

[^DBCS2023]: Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In *International Conference on Machine Learning*. PMLR.

## [Methodology](@id klminwassfwdbwd_method)

This algorithm aims to solve the problem

```math
\mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right)
```

where $\mathcal{Q}$ is some family of distributions, often called the variational family.
Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence.
Instead, we focus on minimizing a surrogate objective, the *free energy functional*, which corresponds to the negated evidence lower bound[^JGJS1999], defined as

```math
\mathcal{F}\left(q\right) \triangleq \mathcal{U}\left(q\right) + \mathcal{H}\left(q\right),
```

where

```math
\begin{aligned}
\mathcal{U}\left(q\right) &= \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right)
&&\text{(``potential energy'')}
\\
\mathcal{H}\left(q\right) &= \mathbb{E}_{\theta \sim q} \log q\left(\theta\right) .
&&\text{(``Boltzmann entropy'')}
\end{aligned}
```
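
In particular, by the definitions above, the free energy is exactly the negated ELBO:

```math
\mathcal{F}\left(q\right)
= -\left(
\mathbb{E}_{\theta \sim q} \log \pi\left(\theta\right)
- \mathbb{E}_{\theta \sim q} \log q\left(\theta\right)
\right)
= -\mathrm{ELBO}\left(q\right) .
```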

To solve this problem, `KLMinWassFwdBwd` relies on proximal stochastic gradient descent (PSGD), also known as "forward-backward splitting", which iterates

```math
q_{t+1} = \mathrm{JKO}_{\gamma_t \mathcal{H}}\big(
q_{t} - \gamma_t \widehat{\nabla_{\mathrm{BW}} \mathcal{U}} (q_{t})
\big) ,
```

where $$\widehat{\nabla_{\mathrm{BW}} \mathcal{U}}$$ is a stochastic estimate of the Bures-Wasserstein measure-valued gradient of $$\mathcal{U}$$, and the JKO (proximal) operator is defined as

```math
\mathrm{JKO}_{\gamma_t \mathcal{H}}(\mu)
=
\argmin_{\nu \in \mathcal{Q}} \left\{ \mathcal{H}(\nu) + \frac{1}{2 \gamma_t} \mathrm{W}_2 {(\mu, \nu)}^2 \right\} ,
```

and $$\mathrm{W}_2$$ is the Wasserstein-2 distance.
When the proximal step is taken over the entire Wasserstein space of probability distributions on $$\mathbb{R}^d$$, this iteration is known as the Jordan-Kinderlehrer-Otto (JKO) scheme[^JKO1998], which was originally developed to study gradient flows under the Wasserstein metric.
Within this context, `KLMinWassFwdBwd` can be viewed as a numerical realization of the JKO scheme obtained by restricting $$\mathcal{Q}$$ to a tractable parametric variational family.
Specifically, Diao *et al.*[^DBCS2023] derived the JKO update for multivariate Gaussians, which is the update implemented by `KLMinWassFwdBwd`.
`KLMinWassFwdBwd` is exactly the measure-space analog of [KLMinRepGradProxDescent](@ref klminrepgradproxdescent).
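
For intuition, when $$\mathcal{Q}$$ is the family of full-rank Gaussians $$q_t = \mathcal{N}(m, \Sigma)$$, both the forward (gradient) step and the backward (JKO) step admit closed forms.
The sketch below condenses the update implemented in `src/algorithms/klminwassfwdbwd.jl`; here `g` and `H` denote Monte Carlo estimates of $$\mathbb{E}_{q_t} \nabla \log \pi$$ and $$\mathbb{E}_{q_t} \nabla^2 \log \pi$$, and `η` is the step size $$\gamma_t$$.

```julia
using LinearAlgebra

# One Gaussian forward-backward (JKO) step for q = N(m, Σ).
# g ≈ E_q[∇ log π], H ≈ E_q[∇² log π], η is the step size.
function gaussian_fwdbwd_step(m, Σ, g, H, η)
    m′ = m + η * g                  # forward step on the mean
    M  = I + η * Hermitian(H)       # forward step on the covariance
    Σ_half = Hermitian(M * Σ * M)
    # backward (JKO/proximal) step for the entropy: closed form for Gaussians
    Σ′ = (Σ_half + 2η * I + sqrt(Hermitian(Σ_half * (Σ_half + 4η * I)))) / 2
    return m′, Σ′
end
```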

[^JKO1998]: Jordan, R., Kinderlehrer, D., & Otto, F. (1998). The variational formulation of the Fokker--Planck equation. *SIAM Journal on Mathematical Analysis*, 29(1).
[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. *Machine Learning*, 37, 183-233.
6 changes: 6 additions & 0 deletions src/AdvancedVI.jl
@@ -352,4 +352,10 @@ include("algorithms/common.jl")

export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

# Other Algorithms

include("algorithms/klminwassfwdbwd.jl")

export KLMinWassFwdBwd

end
16 changes: 8 additions & 8 deletions src/algorithms/constructors.jl
@@ -21,12 +21,12 @@ KL divergence minimization by running stochastic gradient descent with the repar
# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of
# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
@@ -100,12 +100,12 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
# Output
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of
# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
@@ -178,11 +178,11 @@ KL divergence minimization by running stochastic gradient descent with the score
- `q_averaged`: The variational approximation formed by the averaged SGD iterates.

# Callback
The callback function `callback` has a signature of
The `callback` function supplied to `optimize` needs to have the following signature:

callback(; rng, iteration, restructure, params, averaged_params, gradient)

The arguments are as follows:
The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
213 changes: 213 additions & 0 deletions src/algorithms/klminwassfwdbwd.jl
@@ -0,0 +1,213 @@

"""
KLMinWassFwdBwd(n_samples, stepsize, subsampling)
KLMinWassFwdBwd(; n_samples, stepsize, subsampling)

KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023].

Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity.

# (Keyword) Arguments
- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`)
- `stepsize::Float64`: Step size of stochastic proximal gradient descent.
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy.

!!! note
The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinWassFwdBwd` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.
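
For example, a callback that records the ELBO estimate reported in `info` at every iteration could be written as follows (a minimal sketch; `elbo_hist` is a user-defined vector):

    elbo_hist = Float64[]
    callback(; iteration, info, kwargs...) = (push!(elbo_hist, info.elbo); nothing)

Returning `nothing` leaves `info` unchanged; a returned `NamedTuple` is merged into `info` instead.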

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
AbstractVariationalAlgorithm
n_samples::Int = 1
stepsize::Float64
subsampling::Sub = nothing
end

"""
gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob)

Estimate the expectations of the gradient and Hessian of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over.
- `n_samples::Int`: Number of samples used for estimation.
- `grad_buf::AbstractVector`: Buffer for the gradient estimate.
- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate.
- `prob`: `LogDensityProblem` whose log-density gradient and Hessian are averaged over `q`.
"""
function gaussian_expectation_gradient_and_hessian!(
rng::Random.AbstractRNG,
q::MvLocationScale{<:LowerTriangular,<:Normal,L},
n_samples::Int,
grad_buf::AbstractVector{T},
hess_buf::AbstractMatrix{T},
prob,
) where {T<:Real,L}
logπ_avg = zero(T)
fill!(grad_buf, zero(T))
fill!(hess_buf, zero(T))

if LogDensityProblems.capabilities(typeof(prob)) ≤
LogDensityProblems.LogDensityOrder{1}()
# Use Stein's identity
d = LogDensityProblems.dimension(prob)
u = randn(rng, T, d, n_samples)
z = q.scale*u .+ q.location
for b in 1:n_samples
zb, ub = view(z, :, b), view(u, :, b)
logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
logπ_avg += logπ/n_samples
grad_buf += ∇logπ/n_samples
hess_buf += ub*(∇logπ/n_samples)'
end
return logπ_avg, grad_buf, hess_buf
else
# Use sample average of the Hessian.
z = rand(rng, q, n_samples)
for b in 1:n_samples
zb = view(z, :, b)
logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
prob, zb
)
logπ_avg += logπ/n_samples
grad_buf += ∇logπ/n_samples
hess_buf += ∇2logπ/n_samples
end
return logπ_avg, grad_buf, hess_buf
end
end

struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf}
q::Q
prob::P
sigma::Sigma
iteration::Int
sub_st::S
grad_buf::GradBuf
hess_buf::HessBuf
end

function init(
rng::Random.AbstractRNG,
alg::KLMinWassFwdBwd,
q_init::MvLocationScale{<:LowerTriangular,<:Normal,L},
prob,
) where {L}
sub = alg.subsampling
n_dims = LogDensityProblems.dimension(prob)
capability = LogDensityProblems.capabilities(typeof(prob))
if capability < LogDensityProblems.LogDensityOrder{1}()
throw(
ArgumentError(
"`KLMinWassFwdBwd` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
),
)
end
sub_st = isnothing(sub) ? nothing : init(rng, sub)
grad_buf = Vector{eltype(q_init.location)}(undef, n_dims)
hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims)
return KLMinWassFwdBwdState(q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf)
end

output(::KLMinWassFwdBwd, state) = state.q

function step(
rng::Random.AbstractRNG, alg::KLMinWassFwdBwd, state, callback, objargs...; kwargs...
)
(; n_samples, stepsize, subsampling) = alg
(; q, prob, sigma, iteration, sub_st, grad_buf, hess_buf) = state

m = mean(q)
Σ = sigma
η = convert(eltype(m), stepsize)
iteration += 1

# Maybe apply subsampling
prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
prob, sub_st, NamedTuple()
else
batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
prob_sub = subsample(prob, batch)
prob_sub, sub_st′, sub_inf
end

# Estimate the Wasserstein gradient
logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!(
rng, q, n_samples, grad_buf, hess_buf, prob_sub
)

m′ = m - η * (-grad_buf)
M = I - η*Hermitian(-hess_buf)
Σ_half = Hermitian(M*Σ*M)

# Compute the JKO proximal operator
Σ′ = (Σ_half + 2*η*I + sqrt(Hermitian(Σ_half*(Σ_half + 4*η*I))))/2
q′ = MvLocationScale(m′, cholesky(Σ′).L, q.dist)

state = KLMinWassFwdBwdState(q′, prob, Σ′, iteration, sub_st′, grad_buf, hess_buf)
elbo = logπ_avg + entropy(q′)
info = merge((elbo=elbo,), sub_inf)

if !isnothing(callback)
info′ = callback(; rng, iteration, q, info)
info = !isnothing(info′) ? merge(info′, info) : info
end
state, false, info
end

"""
estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the ELBO of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::KLMinWassFwdBwd`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: same as the number of samples used for estimating the gradient during optimization.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
rng::Random.AbstractRNG,
alg::KLMinWassFwdBwd,
q::MvLocationScale{S,<:Normal,L},
prob;
n_samples::Int=alg.n_samples,
) where {S,L}
obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy())
if isnothing(alg.subsampling)
return estimate_objective(rng, obj, q, prob)
else
sub = alg.subsampling
sub_st = init(rng, sub)
return mapreduce(+, 1:length(sub)) do _
batch, sub_st, _ = step(rng, sub, sub_st)
prob_sub = subsample(prob, batch)
estimate_objective(rng, obj, q, prob_sub) / length(sub)
end
end
end