From 342eaf5b1a2e34230eba667f95de26b0b06b6920 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 03:53:08 -0400 Subject: [PATCH 01/28] refactor split shared batch reshuffling into `src/reshuffling.jl` --- src/AdvancedVI.jl | 2 + src/algorithms/subsampledobjective.jl | 53 --------------------------- src/reshuffling.jl | 51 ++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 53 deletions(-) create mode 100644 src/reshuffling.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 1df3fd60a..8fb85d7ca 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -297,6 +297,8 @@ subsample(model_or_q::Any, ::Any) = model_or_q abstract type AbstractSubsampling end +include("reshuffling.jl") + export ReshufflingBatchSubsampling # Main optimization routine diff --git a/src/algorithms/subsampledobjective.jl b/src/algorithms/subsampledobjective.jl index a904eb245..b71e7ab47 100644 --- a/src/algorithms/subsampledobjective.jl +++ b/src/algorithms/subsampledobjective.jl @@ -1,56 +1,3 @@ - -""" - ReshufflingBatchSubsampling(dataset, batchsize) - -Random reshuffling subsampling strategy. -At each 'epoch', this strategy splits the dataset into batches of `batchsize`, shuffles the order of the batches, and goes through each of them in that order. -Processing a single batch is referred as a 'step.' -At the end of each epoch, which means we've gone through all the batches exactly once, we repeat the whole process. - -# Arguments -- `dataset::AbstractVector`: Iterable sequence representing the dataset. -- `batchsize::Int`: Number of data points in each batch. If the number of data points is not exactly dividable by `batchsize`, the last batch may contain less data points than `batchsize`. -""" -struct ReshufflingBatchSubsampling{DataSet<:AbstractVector} <: AbstractSubsampling - dataset::DataSet - batchsize::Int -end - -struct ReshufflingBatchSubsamplingState{It} - epoch::Int - iterator::It -end - -Base.length(sub::ReshufflingBatchSubsampling) = ceil(Int, length(sub.dataset)/sub.batchsize) - -function reshuffle_batches(rng::Random.AbstractRNG, sub::ReshufflingBatchSubsampling) - (; dataset, batchsize) = sub - shuffled = Random.shuffle(rng, dataset) - batches = Iterators.partition(shuffled, batchsize) - return enumerate(batches) -end - -function init(rng::Random.AbstractRNG, sub::ReshufflingBatchSubsampling) - return ReshufflingBatchSubsamplingState(1, reshuffle_batches(rng, sub)) -end - -function step( - rng::Random.AbstractRNG, - sub::ReshufflingBatchSubsampling, - state::ReshufflingBatchSubsamplingState, -) - (; epoch, iterator) = state - (sub_step, batch), batch_it′ = Iterators.peel(iterator) - epoch′, iterator′′ = if isempty(batch_it′) - epoch + 1, reshuffle_batches(rng, sub) - else - epoch, batch_it′ - end - info = (epoch=epoch, step=sub_step) - state′ = ReshufflingBatchSubsamplingState(epoch′, iterator′′) - return batch, state′, info -end - """ SubsampledObjective(objective, subsampling) diff --git a/src/reshuffling.jl b/src/reshuffling.jl new file mode 100644 index 000000000..a1ffdece1 --- /dev/null +++ b/src/reshuffling.jl @@ -0,0 +1,51 @@ +""" + ReshufflingBatchSubsampling(dataset, batchsize) + +Random reshuffling subsampling strategy. +At each 'epoch', this strategy splits the dataset into batches of `batchsize`, shuffles the order of the batches, and goes through each of them in that order. +Processing a single batch is referred as a 'step.' +At the end of each epoch, which means we've gone through all the batches exactly once, we repeat the whole process. 
+ +# Arguments +- `dataset::AbstractVector`: Iterable sequence representing the dataset. +- `batchsize::Int`: Number of data points in each batch. If the number of data points is not exactly dividable by `batchsize`, the last batch may contain less data points than `batchsize`. +""" +struct ReshufflingBatchSubsampling{DataSet<:AbstractVector} <: AbstractSubsampling + dataset::DataSet + batchsize::Int +end + +struct ReshufflingBatchSubsamplingState{It} + epoch::Int + iterator::It +end + +Base.length(sub::ReshufflingBatchSubsampling) = ceil(Int, length(sub.dataset)/sub.batchsize) + +function reshuffle_batches(rng::Random.AbstractRNG, sub::ReshufflingBatchSubsampling) + (; dataset, batchsize) = sub + shuffled = Random.shuffle(rng, dataset) + batches = Iterators.partition(shuffled, batchsize) + return enumerate(batches) +end + +function init(rng::Random.AbstractRNG, sub::ReshufflingBatchSubsampling) + return ReshufflingBatchSubsamplingState(1, reshuffle_batches(rng, sub)) +end + +function step( + rng::Random.AbstractRNG, + sub::ReshufflingBatchSubsampling, + state::ReshufflingBatchSubsamplingState, +) + (; epoch, iterator) = state + (sub_step, batch), batch_it′ = Iterators.peel(iterator) + epoch′, iterator′′ = if isempty(batch_it′) + epoch + 1, reshuffle_batches(rng, sub) + else + epoch, batch_it′ + end + info = (epoch=epoch, step=sub_step) + state′ = ReshufflingBatchSubsamplingState(epoch′, iterator′′) + return batch, state′, info +end From f9ddebbb3c1402e7ea0b4f888060b82efc247204 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 06:47:21 -0400 Subject: [PATCH 02/28] refactor move `SubsampledNormals` code under `test/models/` --- test/algorithms/subsampledobj.jl | 62 ++++++++------------------------ test/models/subsamplednormals.jl | 50 ++++++++++++++++++++++++++ test/runtests.jl | 2 +- 3 files changed, 66 insertions(+), 48 deletions(-) create mode 100644 test/models/subsamplednormals.jl diff --git a/test/algorithms/subsampledobj.jl b/test/algorithms/subsampledobj.jl index c5b8720ec..9c221a9fb 100644 --- a/test/algorithms/subsampledobj.jl +++ b/test/algorithms/subsampledobj.jl @@ -1,44 +1,12 @@ -struct SubsampledNormals{D<:Normal,F<:Real} - dists::Vector{D} - likeadj::F -end - -function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int) - μs = randn(rng, n_normals) - σs = ones(n_normals) - dists = Normal.(μs, σs) - SubsampledNormals{eltype(dists),Float64}(dists, 1.0) -end - -function LogDensityProblems.logdensity(m::SubsampledNormals, x) - (; likeadj, dists) = m - likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists) -end - -function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x) - return ( - LogDensityProblems.logdensity(m, x), - ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x), - ) -end - -function LogDensityProblems.capabilities(::Type{<:SubsampledNormals}) - return LogDensityProblems.LogDensityOrder{1}() -end - -function AdvancedVI.subsample(m::SubsampledNormals, idx) - n_data = length(m.dists) - SubsampledNormals(m.dists[idx], n_data/length(idx)) -end - @testset "SubsampledObjective" begin seed = (0x38bef07cf9cc549d) n_data = 8 - prob = SubsampledNormals(Random.default_rng(), n_data) - μ0 = [mean([mean(dist) for dist in prob.dists])] - q0 = MeanFieldGaussian(μ0, Diagonal(ones(1))) + modelstats = subsamplednormal(n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + q0 = MeanFieldGaussian(μ_true, Diagonal(diag(L_true))) full_obj = RepGradELBO(10) @testset "algorithm constructors" begin @@ -47,17 +15,17 @@ end alg = 
KLMinRepGradDescent( AD; n_samples=10, subsampling=sub, operator=ClipScale() ) - _, info, _ = optimize(alg, 10, prob, q0; show_progress=false) + _, info, _ = optimize(alg, 10, model, q0; show_progress=false) @test isfinite(last(info).elbo) alg = KLMinRepGradProxDescent(AD; n_samples=10, subsampling=sub) - _, info, _ = optimize(alg, 10, prob, q0; show_progress=false) + _, info, _ = optimize(alg, 10, model, q0; show_progress=false) @test isfinite(last(info).elbo) alg = KLMinScoreGradDescent( AD; n_samples=100, subsampling=sub, operator=ClipScale() ) - _, info, _ = optimize(alg, 10, prob, q0; show_progress=false) + _, info, _ = optimize(alg, 10, model, q0; show_progress=false) @test isfinite(last(info).elbo) end end @@ -69,25 +37,25 @@ end sub_obj = alg.objective rng = StableRNG(seed) - q_avg, _, _ = optimize(rng, alg, T, prob, q0; show_progress=false) + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=false) rng = StableRNG(seed) - q_avg_ref, _, _ = optimize(rng, alg, T, prob, q0; show_progress=false) + q_avg_ref, _, _ = optimize(rng, alg, T, model, q0; show_progress=false) @test q_avg == q_avg_ref rng = StableRNG(seed) - sub_objval_ref = estimate_objective(rng, sub_obj, q0, prob) + sub_objval_ref = estimate_objective(rng, sub_obj, q0, model) rng = StableRNG(seed) - sub_objval = estimate_objective(rng, sub_obj, q0, prob) + sub_objval = estimate_objective(rng, sub_obj, q0, model) @test sub_objval == sub_objval_ref end @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] sub = ReshufflingBatchSubsampling(1:n_data, batchsize) sub_obj′ = SubsampledObjective(full_obj, sub) - full_objval = estimate_objective(full_obj, q0, prob; n_samples=10^8) - sub_objval = estimate_objective(sub_obj′, q0, prob; n_samples=10^8) + full_objval = estimate_objective(full_obj, q0, model; n_samples=10^8) + sub_objval = estimate_objective(sub_obj′, q0, model; n_samples=10^8) @test full_objval ≈ sub_objval rtol=0.1 end @@ -100,7 +68,7 @@ end # Estimate using full batch rng = StableRNG(seed) - full_state = AdvancedVI.init(rng, full_obj, AD, q0, prob, params, restructure) + full_state = AdvancedVI.init(rng, full_obj, AD, q0, model, params, restructure) AdvancedVI.estimate_gradient!( rng, full_obj, AD, out, full_state, params, restructure ) @@ -108,7 +76,7 @@ end # Estimate the full batch gradient by averaging the minibatch gradients rng = StableRNG(seed) - sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, prob, params, restructure) + sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, model, params, restructure) grad = mean(1:length(sub_obj.subsampling)) do _ # Fixing the RNG so that the same Monte Carlo samples are used across the batches rng = StableRNG(seed) diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl new file mode 100644 index 000000000..031a6de38 --- /dev/null +++ b/test/models/subsamplednormals.jl @@ -0,0 +1,50 @@ + +struct SubsampledNormals{D<:Normal,F<:Real,C} + dists::Vector{D} + likeadj::F + cap::C +end + +function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int) + μs = randn(rng, n_normals) + σs = ones(n_normals) + dists = Normal.(μs, σs) + return SubsampledNormals{eltype(dists),Float64}(dists, 1.0) +end + +function LogDensityProblems.logdensity(m::SubsampledNormals, x) + (; likeadj, dists) = m + return likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists) +end + +function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x) + return ( + LogDensityProblems.logdensity(m, x), + 
ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x), + ) +end + +function LogDensityProblems.capabilities(::Type{SubsampledNormals{D,F,C}}) where {D,F,C} + return C() +end + +function AdvancedVI.subsample(m::SubsampledNormals, idx) + n_data = length(m.dists) + return SubsampledNormals(m.dists[idx], n_data/length(idx)) +end + +function subsamplednormal(n_data::Int; capability::Int=1) + cap = if capability == 1 + LogDensityProblems.LogDensityOrder{1}() + elseif capability == 2 + LogDensityProblems.LogDensityOrder{2}() + else + LogDensityProblems.LogDensityOrder{0}() + end + model = SubsampledNormals(Random.default_rng(), n_data, cap) + + n_dims = 1 + μ_true = [mean([mean(dist) for dist in prob.dists])] + L_true = Diagonal(ones(1)) + return TestModel(model, μ_true, L_true, n_dims, 1, true) +end diff --git a/test/runtests.jl b/test/runtests.jl index 2dae5b31c..47bf0846b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -53,9 +53,9 @@ struct TestModel{M,L,S,SC} end include("models/normal.jl") include("models/normallognormal.jl") +include("models/subsamplednormals.jl") if GROUP == "All" || GROUP == "GENERAL" - # Tests that do not need to check correct integration with AD backends include("general/optimize.jl") include("general/proximal_location_scale_entropy.jl") include("general/rules.jl") From 563bb5d8669a37bed89dd1555f26490efe88b07b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 06:55:29 -0400 Subject: [PATCH 03/28] fix wrong capability for subsamplednormals --- test/models/subsamplednormals.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl index 031a6de38..c187f1ae4 100644 --- a/test/models/subsamplednormals.jl +++ b/test/models/subsamplednormals.jl @@ -24,8 +24,16 @@ function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x) ) end -function LogDensityProblems.capabilities(::Type{SubsampledNormals{D,F,C}}) where {D,F,C} - return C() +function LogDensityProblems.logdensity_gradient_and_hessian(m::SubsampledNormals, x) + return ( + LogDensityProblems.logdensity(m, x), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x), + ForwardDiff.hessian(Base.Fix1(LogDensityProblems.logdensity, m), x), + ) +end + +function LogDensityProblems.capabilities(::Type{<:SubsampledNormals}) + return LogDensityProblems.LogDensityOrder{2}() end function AdvancedVI.subsample(m::SubsampledNormals, idx) @@ -33,16 +41,8 @@ function AdvancedVI.subsample(m::SubsampledNormals, idx) return SubsampledNormals(m.dists[idx], n_data/length(idx)) end -function subsamplednormal(n_data::Int; capability::Int=1) - cap = if capability == 1 - LogDensityProblems.LogDensityOrder{1}() - elseif capability == 2 - LogDensityProblems.LogDensityOrder{2}() - else - LogDensityProblems.LogDensityOrder{0}() - end - model = SubsampledNormals(Random.default_rng(), n_data, cap) - +function subsamplednormal(n_data::Int) + model = SubsampledNormals(Random.default_rng(), n_data) n_dims = 1 μ_true = [mean([mean(dist) for dist in prob.dists])] L_true = Diagonal(ones(1)) From 6674b427d94b3814b7bf059b20df7079d01941dc Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 07:00:30 -0400 Subject: [PATCH 04/28] fix remove unused fields --- test/models/subsamplednormals.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl index c187f1ae4..e1da9d635 100644 --- 
a/test/models/subsamplednormals.jl +++ b/test/models/subsamplednormals.jl @@ -1,8 +1,7 @@ -struct SubsampledNormals{D<:Normal,F<:Real,C} +struct SubsampledNormals{D<:Normal,F<:Real} dists::Vector{D} likeadj::F - cap::C end function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int) From bd0b88950567223664bbb2815a653324b84928a6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 07:05:22 -0400 Subject: [PATCH 05/28] fix wrong variable name, add missing rng argument --- test/algorithms/subsampledobj.jl | 2 +- test/models/subsamplednormals.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/algorithms/subsampledobj.jl b/test/algorithms/subsampledobj.jl index 9c221a9fb..c2907e62c 100644 --- a/test/algorithms/subsampledobj.jl +++ b/test/algorithms/subsampledobj.jl @@ -3,7 +3,7 @@ seed = (0x38bef07cf9cc549d) n_data = 8 - modelstats = subsamplednormal(n_data) + modelstats = subsamplednormal(Random.default_rng(), n_data) (; model, n_dims, μ_true, L_true) = modelstats q0 = MeanFieldGaussian(μ_true, Diagonal(diag(L_true))) diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl index e1da9d635..86be0dc1a 100644 --- a/test/models/subsamplednormals.jl +++ b/test/models/subsamplednormals.jl @@ -40,10 +40,10 @@ function AdvancedVI.subsample(m::SubsampledNormals, idx) return SubsampledNormals(m.dists[idx], n_data/length(idx)) end -function subsamplednormal(n_data::Int) - model = SubsampledNormals(Random.default_rng(), n_data) +function subsamplednormal(rng::Random.AbstractRNG, n_data::Int) + model = SubsampledNormals(rng, n_data) n_dims = 1 - μ_true = [mean([mean(dist) for dist in prob.dists])] + μ_true = [mean([mean(dist) for dist in model.dists])] L_true = Diagonal(ones(1)) return TestModel(model, μ_true, L_true, n_dims, 1, true) end From bb2fa3a106760751999632969e407eb92fb0d543 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 09:43:07 -0400 Subject: [PATCH 06/28] fix missing `dimension` and wrong variance in subsampled normals --- test/models/subsamplednormals.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl index 86be0dc1a..6b185e332 100644 --- a/test/models/subsamplednormals.jl +++ b/test/models/subsamplednormals.jl @@ -31,6 +31,10 @@ function LogDensityProblems.logdensity_gradient_and_hessian(m::SubsampledNormals ) end +function LogDensityProblems.dimension(::SubsampledNormals) + return 1 +end + function LogDensityProblems.capabilities(::Type{<:SubsampledNormals}) return LogDensityProblems.LogDensityOrder{2}() end @@ -44,6 +48,6 @@ function subsamplednormal(rng::Random.AbstractRNG, n_data::Int) model = SubsampledNormals(rng, n_data) n_dims = 1 μ_true = [mean([mean(dist) for dist in model.dists])] - L_true = Diagonal(ones(1)) + L_true = Diagonal([sqrt(1/n_data)]) return TestModel(model, μ_true, L_true, n_dims, 1, true) end From 7b6246fe1b7dd75dd2b42751024dd6c380d91a82 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 09:58:24 -0400 Subject: [PATCH 07/28] add Wasserstein VI algorithm --- src/AdvancedVI.jl | 6 ++ src/algorithms/klminwassfwdbwd.jl | 167 ++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+) create mode 100644 src/algorithms/klminwassfwdbwd.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 8fb85d7ca..2b82b7b79 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -352,4 +352,10 @@ include("algorithms/common.jl") export KLMinRepGradDescent, KLMinRepGradProxDescent, 
KLMinScoreGradDescent, ADVI, BBVI +# Other Algorithms + +include("algorithms/klminwassfwdbwd.jl") + +export KLMinWassFwdBwd + end diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl new file mode 100644 index 000000000..33b044e10 --- /dev/null +++ b/src/algorithms/klminwassfwdbwd.jl @@ -0,0 +1,167 @@ + +""" + KLMinWassFwdBwd(n_samples, stepsize, subsampling) + KLMinWassFwdBwd(; n_samples, stepsize, subsampling) + +KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023]. + +# (Keyword) Arguments +- `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) +- `stepsize::Float64`: Step size of stochastic proximal gradient descent. +- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. + +!!! note + The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinWassFwdBwd` does not support amortization or structured variational families. + +# Output +- `q`: The last iterate of the algorithm. + +# Callback +The callback function `callback` has a signature of + + callback(; rng, iteration, q, info) + +The arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `q`: Current variational approximation. +- `info`: `NamedTuple` containing the information generated during the current iteration. + +# Requirements +- The variational family is [`FullRankGaussian`](@ref FullRankGaussian). +- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). +- The target `LogDensityProblems.logdensity(prob, x)` has second-order differentiation capability. (`KLMinWassFwdBwd` uses Hessians of the log-density.) +""" +@kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <: + AbstractVariationalAlgorithm + n_samples::Int = 1 + stepsize::Float64 + subsampling::Sub = nothing +end + +struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf} + q::Q + prob::P + sigma::Sigma + iteration::Int + sub_st::S + grad_buf::GradBuf + hess_buf::HessBuf +end + +function init( + rng::Random.AbstractRNG, + alg::KLMinWassFwdBwd, + q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, + prob, +) where {L} + sub = alg.subsampling + n_dims = LogDensityProblems.dimension(prob) + capability = LogDensityProblems.capabilities(typeof(prob)) + if capability < LogDensityProblems.LogDensityOrder{2}() + throw( + ArgumentError( + "`KLMinWassFwdBwd` requires second-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + ), + ) + end + sub_st = isnothing(sub) ? nothing : init(rng, sub) + grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) + hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) + return KLMinWassFwdBwdState(q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf) +end + +output(::KLMinWassFwdBwd, state) = state.q + +function step( + rng::Random.AbstractRNG, alg::KLMinWassFwdBwd, state, callback, objargs...; kwargs... 
+)
+    (; n_samples, stepsize, subsampling) = alg
+    (; q, prob, sigma, iteration, sub_st, grad_buf, hess_buf) = state
+
+    m = mean(q)
+    Σ = sigma
+    η = convert(eltype(m), stepsize)
+    iteration += 1
+
+    # Maybe apply subsampling
+    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
+        prob, sub_st, NamedTuple()
+    else
+        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
+        prob_sub = subsample(prob, batch)
+        prob_sub, sub_st′, sub_inf
+    end
+
+    # Estimate the moments required for computing the Wasserstein gradient
+    z = rand(rng, q, n_samples)
+    V_avg = 0
+    fill!(grad_buf, zero(eltype(grad_buf)))
+    fill!(hess_buf, zero(eltype(hess_buf)))
+    for b in 1:n_samples
+        negVb, neg∇Vb, neg∇2Vb = LogDensityProblems.logdensity_gradient_and_hessian(
+            prob_sub, z[:, b]
+        )
+        V_avg += -negVb/n_samples
+        grad_buf += -neg∇Vb/n_samples
+        hess_buf += -neg∇2Vb/n_samples
+    end
+
+    m′ = m - η*grad_buf
+    M = I - η*Hermitian(hess_buf)
+    Σ_half = Hermitian(M*Σ*M)
+
+    # Compute the JKO proximal operator
+    Σ′ = (Σ_half + 2*η*I + sqrt(Hermitian(Σ_half*(Σ_half + 4*η*I))))/2
+    q′ = MvLocationScale(m′, cholesky(Σ′).L, q.dist)
+
+    state = KLMinWassFwdBwdState(q′, prob, Σ′, iteration, sub_st′, grad_buf, hess_buf)
+    elbo = -V_avg + entropy(q′)
+    info = merge((elbo=elbo,), sub_inf)
+
+    if !isnothing(callback)
+        info′ = callback(; rng, iteration, q, info)
+        info = !isnothing(info′) ? merge(info′, info) : info
+    end
+    state, false, info
+end
+
+"""
+    estimate_objective([rng,] alg, q, prob; n_samples)
+
+Estimate the ELBO of the variational approximation `q` against the target log-density `prob`.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `alg::KLMinWassFwdBwd`: Variational inference algorithm.
+- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
+
+# Keyword Arguments
+- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used for estimating the gradient during optimization.)
+
+# Returns
+- `obj_est`: Estimate of the objective value.
+""" +function estimate_objective( + rng::Random.AbstractRNG, + alg::KLMinWassFwdBwd, + q::MvLocationScale{S,<:Normal,L}, + prob; + n_samples::Int=alg.n_samples, +) where {S,L} + obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) + if isnothing(alg.subsampling) + return estimate_objective( + rng, obj, q, prob + ) + else + sub = alg.subsampling + sub_st = init(rng, sub) + return mapreduce(+, 1:length(sub)) do _ + batch, sub_st, _ = step(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + estimate_objective(rng, obj, q, prob_sub) / length(sub) + end + end +end From a3201f8c38621f0924f04ddbba42db49b7a4eda6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 09:59:02 -0400 Subject: [PATCH 08/28] add tests for Wasserstein VI --- test/algorithms/klminwassfwdbwd.jl | 144 +++++++++++++++++++++++++++++ test/models/normal.jl | 37 ++++++-- test/runtests.jl | 2 + 3 files changed, 176 insertions(+), 7 deletions(-) create mode 100644 test/algorithms/klminwassfwdbwd.jl diff --git a/test/algorithms/klminwassfwdbwd.jl b/test/algorithms/klminwassfwdbwd.jl new file mode 100644 index 000000000..147cc10df --- /dev/null +++ b/test/algorithms/klminwassfwdbwd.jl @@ -0,0 +1,144 @@ + +@testset "KLMinWassFwdBwd" begin + begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + (; model, n_dims, μ_true, L_true) = modelstats + + alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3) + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "callback" begin + T = 10 + callback(; iteration, kwargs...) = (iteration_check=iteration,) + _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS) + @test [i.iteration_check for i in info] == 1:T + end + + @testset "estimate_objective" begin + q_true = FullRankGaussian(μ_true, LowerTriangular(Matrix(L_true))) + + obj_est = estimate_objective(alg, q_true, model) + @test isfinite(obj_est) + + obj_est = estimate_objective(alg, q_true, model; n_samples=10^5) + @test obj_est ≈ 0 atol=1e-2 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + T = 10 + + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end + + begin + alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1.0) + + @testset "error low capability" begin + modelstats = normal_meanfield(Random.default_rng(), Float64) + (; model, n_dims) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + @test_throws "second-order" optimize(alg, 1, model, q0) + end + end + + @testset "type stability $(realtype)" for realtype in [Float64, Float32] + modelstats = normal_meanfield(Random.default_rng(), realtype; capability=2) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3) + T = 10 + + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(realtype, n_dims), L0) + + q, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + @test eltype(q.location) == eltype(μ_true) + @test eltype(q.scale) == eltype(L_true) + end + + @testset "convergence" begin + modelstats = normal_meanfield(Random.default_rng(), Float64; 
capability=2) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3) + + q_avg, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q_avg.location - μ_true) + sum(abs2, q_avg.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + + @testset "subsampling" begin + n_data = 8 + + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3) + alg_sub = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3, subsampling) + + obj_full = estimate_objective(alg, q0, model; n_samples=10^5) + obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5) + @test obj_full ≈ obj_sub rtol=0.1 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + T = 10 + batchsize = 3 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3, subsampling) + + q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + + @testset "convergence" begin + T = 1000 + batchsize = 1 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-2, subsampling) + + q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + end +end diff --git a/test/models/normal.jl b/test/models/normal.jl index 3db1af08d..a8de0f336 100644 --- a/test/models/normal.jl +++ b/test/models/normal.jl @@ -1,7 +1,8 @@ -struct TestNormal{M,S} +struct TestNormal{M,S,C} μ::M Σ::S + cap::C end function LogDensityProblems.logdensity(model::TestNormal, θ) @@ -16,15 +17,23 @@ function LogDensityProblems.logdensity_and_gradient(model::TestNormal, θ) ) end +function LogDensityProblems.logdensity_gradient_and_hessian(model::TestNormal, θ) + return ( + LogDensityProblems.logdensity(model, θ), + ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ForwardDiff.hessian(Base.Fix1(LogDensityProblems.logdensity, model), θ), + ) +end + function LogDensityProblems.dimension(model::TestNormal) return length(model.μ) end -function LogDensityProblems.capabilities(::Type{<:TestNormal}) - return LogDensityProblems.LogDensityOrder{1}() +function LogDensityProblems.capabilities(::Type{TestNormal{M,S,C}}) where {M,S,C} + return C() end -function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) +function normal_fullrank(rng::Random.AbstractRNG, realtype::Type; capability::Int=1) n_dims = 5 σ0 = realtype(0.3) @@ -32,19 +41,33 @@ function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) L = Matrix(σ0 * I, n_dims, n_dims) Σ = Hermitian(L * L') - model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) + cap = if capability == 1 + 
LogDensityProblems.LogDensityOrder{1}() + elseif capability == 2 + LogDensityProblems.LogDensityOrder{2}() + else + LogDensityProblems.LogDensityOrder{0}() + end + model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)), cap) return TestModel(model, μ, LowerTriangular(L), n_dims, 1 / σ0^2, false) end -function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) +function normal_meanfield(rng::Random.AbstractRNG, realtype::Type; capability::Int=1) n_dims = 5 σ0 = realtype(0.3) μ = Fill(realtype(5), n_dims) σ = Fill(σ0, n_dims) - model = TestNormal(μ, Diagonal(σ .^ 2)) + cap = if capability == 1 + LogDensityProblems.LogDensityOrder{1}() + elseif capability == 2 + LogDensityProblems.LogDensityOrder{2}() + else + LogDensityProblems.LogDensityOrder{0}() + end + model = TestNormal(μ, Diagonal(σ .^ 2), cap) L = Diagonal(σ) diff --git a/test/runtests.jl b/test/runtests.jl index 47bf0846b..ab67247b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -64,6 +64,8 @@ if GROUP == "All" || GROUP == "GENERAL" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") + + include("algorithms/klminwassfwdbwd.jl") end if GROUP == "All" || GROUP == "AD" From f478e15da50d0071e72e0d66af3acd1e493d80b6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 09:59:39 -0400 Subject: [PATCH 09/28] add docs for Wasserstein VI --- HISTORY.md | 8 ++++ docs/make.jl | 7 ++-- docs/src/index.md | 1 + docs/src/klminwassfwdbwd.md | 79 +++++++++++++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 docs/src/klminwassfwdbwd.md diff --git a/HISTORY.md b/HISTORY.md index 440f765d4..bb0a34425 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,3 +1,11 @@ + +# Release 0.6 + +This update adds new variational inference algorithms in light of the flexibility added in the v0.5 update. 
+Specifically, the following measure-space optimization algorithms have been added: + +- `KLMinWassFwdBwd` + # Release 0.5 ## Default Configuration Changes diff --git a/docs/make.jl b/docs/make.jl index 3c1d3a2d7..2cacf5fcb 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -23,9 +23,10 @@ makedocs(; "Normalizing Flows" => "tutorials/flows.md", ], "Algorithms" => [ - "KLMinRepGradDescent" => "klminrepgraddescent.md", - "KLMinRepGradProxDescent" => "klminrepgradproxdescent.md", - "KLMinScoreGradDescent" => "klminscoregraddescent.md", + "`KLMinRepGradDescent`" => "klminrepgraddescent.md", + "`KLMinRepGradProxDescent`" => "klminrepgradproxdescent.md", + "`KLMinScoreGradDescent`" => "klminscoregraddescent.md", + "`KLMinWassFwdBwd`" => "klminwassfwdbwd.md", ], "Variational Families" => "families.md", "Optimization" => "optimization.md", diff --git a/docs/src/index.md b/docs/src/index.md index f82087b00..0abd8c05f 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -17,3 +17,4 @@ For using the algorithms implemented in `AdvancedVI`, refer to the corresponding - [KLMinRepGradDescent](@ref klminrepgraddescent) (alias of `ADVI`) - [KLMinRepGradProxDescent](@ref klminrepgradproxdescent) - [KLMinScoreGradDescent](@ref klminscoregraddescent) (alias of `BBVI`) + - [KLMinWassFwdBwd](@ref klminwassfwdbwd) diff --git a/docs/src/klminwassfwdbwd.md b/docs/src/klminwassfwdbwd.md new file mode 100644 index 000000000..6a88781eb --- /dev/null +++ b/docs/src/klminwassfwdbwd.md @@ -0,0 +1,79 @@ +# [`KLMinWassFwdBwd`](@id klminwassfwdbwd) + +## Description + +This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running proximal gradient descent (also known as forward-backward splitting) in Wasserstein space[^DBCS2023]. +(This algorithm is also sometimes referred to as "Wasserstein VI".) +Since `KLMinWassFwdBwd` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`) that makes the measure-valued operations tractable. +Furthermore, it requires Hessians of the target log-density. + +```@docs +KLMinWassFwdBwd +``` + +The associated objective value, which is the ELBO, can be estimated through the following: + +```@docs; canonical=false +estimate_objective( + ::Random.AbstractRNG, + ::KLMinWassFwdBwd, + ::MvLocationScale, + ::Any; + ::Int, +) +``` + +[^DBCS2023]: Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In *International Conference on Machine Learning*. PMLR. +## [Methodology](@id klminwassfwdbwd_method) + +This algorithm aims to solve the problem + +```math + \mathrm{minimize}_{q \in \mathcal{Q}}\quad \mathrm{KL}\left(q, \pi\right) +``` + +where $\mathcal{Q}$ is some family of distributions, often called the variational family. +Since we usually only have access to the unnormalized densities of the target distribution $\pi$, we don't have direct access to the KL divergence. +Instead, we focus on minimizing a surrogate objective, the *free energy functional*, which corresponds to the negated evidence lower bound[^JGJS1999], defined as + +```math + \mathcal{F}\left(q\right) \triangleq \mathcal{U}\left(q\right) + \mathcal{H}\left(q\right), +``` + +where + +```math +\begin{aligned} + \mathcal{U}\left(q\right) &= \mathbb{E}_{\theta \sim q} -\log \pi\left(\theta\right) + &&\text{(``potential energy'')} + \\ + \mathcal{H}\left(q\right) &= \mathbb{E}_{\theta \sim q} \log q\left(\theta\right) . 
+ &&\text{(``Boltzmann entropy'')} +\end{aligned} +``` + +For solving this problem, `KLMinWassFwdBwd` relies on proximal stochastic gradient descent (PSGD)---also known as "forward-backward splitting"---that iterates + +```math + q_{t+1} = \mathrm{JKO}_{\gamma_t \mathcal{H}}\big( + q_{t} - \gamma_t \widehat{\nabla_{\mathrm{BW}} \mathcal{V}} (q_{t}) + \big) , +``` + +where $$\widehat{\nabla_{\mathrm{BW}} \mathcal{V}}$$ is a stochastic estimate of the Bures-Wasserstein measure-valued gradient of $$\mathcal{V}$$, the JKO (proximal) operator is defined as + +```math +\mathrm{JKO}_{\gamma_t \mathcal{H}}(\mu) += +\argmin_{\nu \in \mathcal{Q}} \left\{ \mathcal{H}(\nu) + \frac{1}{2 \gamma_t} \mathrm{W}_2 {(\mu, \nu)}^2 \right\} , +``` + +and $$\mathrm{W}_2$$ is the Wasserstein-2 distance. +When $$\mathcal{Q}$$ is set to be the Bures-Wasserstein space of $$\mathbb{R}^d$$, this algorithm is referred to as the Jordan-Kinderlehrer-Otto (JKO) scheme[^JKO1998], which was originally developed to study gradient flows under Wasserstein metrics. +Within this context, `KLMinWassFwdBwd` can be viewed as a numerical realization of the JKO scheme. +This also exactly corresponds to the measure-space analog of [KLMinRepGradProxDescent](@ref klminrepgradproxdescent). +Similarly to `KLMinRepGradProxDescent`, the JKO proximal operator to be tractable, which has to be derived for different variational families on a case-by-case basis. +Diao *et al.*[^DBCS2023] derived the JKO update for multivariate Gaussians, which is implemented by `KLMinWassFwdBwd`. + +[^JKO1998]: Jordan, R., Kinderlehrer, D., & Otto, F. (1998). The variational formulation of the Fokker--Planck equation. *SIAM Journal on Mathematical Analysis*, 29(1). +[^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233. From 9583ef95f58a82db5a9dc5fb8a133f26aa57314e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 10:05:55 -0400 Subject: [PATCH 10/28] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- HISTORY.md | 1 - 1 file changed, 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index bb0a34425..b9e77cc50 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,4 +1,3 @@ - # Release 0.6 This update adds new variational inference algorithms in light of the flexibility added in the v0.5 update. From ba31866fad1dd31cff802f676fc6272bcbc23634 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 10:06:01 -0400 Subject: [PATCH 11/28] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index b9e77cc50..067c6bb02 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -3,7 +3,7 @@ This update adds new variational inference algorithms in light of the flexibility added in the v0.5 update. 
Specifically, the following measure-space optimization algorithms have been added: -- `KLMinWassFwdBwd` + - `KLMinWassFwdBwd` # Release 0.5 From 455118fdaaa41ffcc83417632c9c9053ec65762c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 25 Oct 2025 10:06:08 -0400 Subject: [PATCH 12/28] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/algorithms/klminwassfwdbwd.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 33b044e10..bef6eb2bc 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -152,9 +152,7 @@ function estimate_objective( ) where {S,L} obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) if isnothing(alg.subsampling) - return estimate_objective( - rng, obj, q, prob - ) + return estimate_objective(rng, obj, q, prob) else sub = alg.subsampling sub_st = init(rng, sub) From 83215558f052aa99c1d780d676dbfc6963dda1a6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 26 Oct 2025 01:19:05 -0400 Subject: [PATCH 13/28] fix typos --- docs/src/klminwassfwdbwd.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/klminwassfwdbwd.md b/docs/src/klminwassfwdbwd.md index 6a88781eb..a49808910 100644 --- a/docs/src/klminwassfwdbwd.md +++ b/docs/src/klminwassfwdbwd.md @@ -56,11 +56,11 @@ For solving this problem, `KLMinWassFwdBwd` relies on proximal stochastic gradie ```math q_{t+1} = \mathrm{JKO}_{\gamma_t \mathcal{H}}\big( - q_{t} - \gamma_t \widehat{\nabla_{\mathrm{BW}} \mathcal{V}} (q_{t}) + q_{t} - \gamma_t \widehat{\nabla_{\mathrm{BW}} \mathcal{U}} (q_{t}) \big) , ``` -where $$\widehat{\nabla_{\mathrm{BW}} \mathcal{V}}$$ is a stochastic estimate of the Bures-Wasserstein measure-valued gradient of $$\mathcal{V}$$, the JKO (proximal) operator is defined as +where $$\widehat{\nabla_{\mathrm{BW}} \mathcal{U}}$$ is a stochastic estimate of the Bures-Wasserstein measure-valued gradient of $$\mathcal{U}$$, the JKO (proximal) operator is defined as ```math \mathrm{JKO}_{\gamma_t \mathcal{H}}(\mu) From c985932d6448531074393edb890a46b6baebaadd Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 29 Oct 2025 08:28:10 -0400 Subject: [PATCH 14/28] add optional Stein's identity expected Hessian estimator --- src/algorithms/klminwassfwdbwd.jl | 83 +++++++++++++++++++++++------- test/algorithms/klminwassfwdbwd.jl | 16 +++--- test/models/subsamplednormals.jl | 30 +++++++---- 3 files changed, 94 insertions(+), 35 deletions(-) diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index bef6eb2bc..5960cf3f3 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -5,6 +5,8 @@ KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023]. +Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. It the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. 
+ # (Keyword) Arguments - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) - `stepsize::Float64`: Step size of stochastic proximal gradient descent. @@ -30,7 +32,7 @@ The arguments are as follows: # Requirements - The variational family is [`FullRankGaussian`](@ref FullRankGaussian). - The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). -- The target `LogDensityProblems.logdensity(prob, x)` has second-order differentiation capability. (`KLMinWassFwdBwd` uses Hessians of the log-density.) +- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. """ @kwdef struct KLMinWassFwdBwd{Sub<:Union{Nothing,<:AbstractSubsampling}} <: AbstractVariationalAlgorithm @@ -39,6 +41,58 @@ The arguments are as follows: subsampling::Sub = nothing end +""" + gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) + +Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function internally uses the Bonnet-Price estimator. Otherwise, it uses Stein's identity. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. +- `n_samples::Int`: Number of samples used for estimation. +- `grad_buf::AbstractVector`: Buffer for the gradient estimate. +- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate. +- `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation. +""" +function gaussian_expectation_gradient_and_hessian!( + rng::Random.AbstractRNG, + q::MvLocationScale{<:LowerTriangular,<:Normal,L}, + n_samples::Int, + grad_buf::AbstractVector{T}, + hess_buf::AbstractMatrix{T}, + prob +) where {T<:Real,L} + logπ_avg = zero(T) + fill!(grad_buf, zero(T)) + fill!(hess_buf, zero(T)) + + if LogDensityProblems.capabilities(typeof(prob)) ≤ LogDensityProblems.LogDensityOrder{1}() + # Use Stein's identity + d = LogDensityProblems.dimension(prob) + u = randn(rng, T, d, n_samples) + z = q.scale*u .+ q.location + for b in 1:n_samples + zb, ub = view(z, :, b), view(u, :, b) + logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) + logπ_avg += logπ/n_samples + grad_buf += ∇logπ/n_samples + hess_buf += ub*(∇logπ/n_samples)' + end + return logπ_avg, grad_buf, hess_buf + else + # Use sample average of the Hessian. + z = rand(rng, q, n_samples) + for b in 1:n_samples + zb = view(z, :, b) + logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(prob, zb) + logπ_avg += logπ/n_samples + grad_buf += ∇logπ/n_samples + hess_buf += ∇2logπ/n_samples + end + return logπ_avg, grad_buf, hess_buf + end +end + struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf} q::Q prob::P @@ -58,10 +112,10 @@ function init( sub = alg.subsampling n_dims = LogDensityProblems.dimension(prob) capability = LogDensityProblems.capabilities(typeof(prob)) - if capability < LogDensityProblems.LogDensityOrder{2}() + if capability < LogDensityProblems.LogDensityOrder{1}() throw( ArgumentError( - "`KLMinWassFwdBwd` requires second-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + "`KLMinWassFwdBwd` requires at least first-order differentiation capability. 
The capability of the supplied `LogDensityProblem` is $(capability).", ), ) end @@ -93,22 +147,13 @@ function step( prob_sub, sub_st′, sub_inf end - # Estimate the moments required for computing the Wasserstein gradient - z = rand(rng, q, n_samples) - V_avg = 0 - fill!(grad_buf, zero(eltype(grad_buf))) - fill!(hess_buf, zero(eltype(hess_buf))) - for b in 1:n_samples - negVb, neg∇Vb, neg∇2Vb = LogDensityProblems.logdensity_gradient_and_hessian( - prob_sub, z[:, b] - ) - V_avg += -negVb/n_samples - grad_buf += -neg∇Vb/n_samples - hess_buf += -neg∇2Vb/n_samples - end + # Estimate the Wasserstein gradient + logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( + rng, q, n_samples, grad_buf, hess_buf, prob_sub + ) - m′ = m - η*grad_buf - M = I - η*Hermitian(hess_buf) + m′ = m - η*-grad_buf + M = I - η*Hermitian(-hess_buf) Σ_half = Hermitian(M*Σ*M) # Compute the JKO proximal operator @@ -116,7 +161,7 @@ function step( q′ = MvLocationScale(m′, cholesky(Σ′).L, q.dist) state = KLMinWassFwdBwdState(q′, prob, Σ′, iteration, sub_st′, grad_buf, hess_buf) - elbo = -V_avg + entropy(q′) + elbo = logπ_avg + entropy(q′) info = merge((elbo=elbo,), sub_inf) if !isnothing(callback) diff --git a/test/algorithms/klminwassfwdbwd.jl b/test/algorithms/klminwassfwdbwd.jl index 147cc10df..8261a7dd5 100644 --- a/test/algorithms/klminwassfwdbwd.jl +++ b/test/algorithms/klminwassfwdbwd.jl @@ -47,17 +47,21 @@ alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1.0) @testset "error low capability" begin - modelstats = normal_meanfield(Random.default_rng(), Float64) + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0) (; model, n_dims) = modelstats L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) q0 = FullRankGaussian(zeros(Float64, n_dims), L0) - @test_throws "second-order" optimize(alg, 1, model, q0) + @test_throws "first-order" optimize(alg, 1, model, q0) end end - @testset "type stability $(realtype)" for realtype in [Float64, Float32] - modelstats = normal_meanfield(Random.default_rng(), realtype; capability=2) + @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [ + Float64, Float32 + ], + capability in [1, 2] + + modelstats = normal_meanfield(Random.default_rng(), realtype; capability) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3) @@ -72,8 +76,8 @@ @test eltype(q.scale) == eltype(L_true) end - @testset "convergence" begin - modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = normal_meanfield(Random.default_rng(), Float64; capability) (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats T = 1000 diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl index 6b185e332..a254bbf24 100644 --- a/test/models/subsamplednormals.jl +++ b/test/models/subsamplednormals.jl @@ -1,14 +1,17 @@ -struct SubsampledNormals{D<:Normal,F<:Real} +struct SubsampledNormals{D<:Normal,F<:Real,C} dists::Vector{D} likeadj::F + capability::C end -function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int) +function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int, capability) μs = randn(rng, n_normals) σs = ones(n_normals) dists = Normal.(μs, σs) - return SubsampledNormals{eltype(dists),Float64}(dists, 1.0) + return SubsampledNormals{eltype(dists),Float64,typeof(capability)}( + dists, 1.0, capability + ) 
end function LogDensityProblems.logdensity(m::SubsampledNormals, x) @@ -35,17 +38,24 @@ function LogDensityProblems.dimension(::SubsampledNormals) return 1 end -function LogDensityProblems.capabilities(::Type{<:SubsampledNormals}) - return LogDensityProblems.LogDensityOrder{2}() +function LogDensityProblems.capabilities(::Type{SubsampledNormals{D,F,C}}) where {D,F,C} + return C() end function AdvancedVI.subsample(m::SubsampledNormals, idx) n_data = length(m.dists) - return SubsampledNormals(m.dists[idx], n_data/length(idx)) -end - -function subsamplednormal(rng::Random.AbstractRNG, n_data::Int) - model = SubsampledNormals(rng, n_data) + return SubsampledNormals(m.dists[idx], n_data/length(idx), m.capability) +end + +function subsamplednormal(rng::Random.AbstractRNG, n_data::Int; capability::Int=1) + cap = if capability == 1 + LogDensityProblems.LogDensityOrder{1}() + elseif capability == 2 + LogDensityProblems.LogDensityOrder{2}() + else + LogDensityProblems.LogDensityOrder{0}() + end + model = SubsampledNormals(rng, n_data, cap) n_dims = 1 μ_true = [mean([mean(dist) for dist in model.dists])] L_true = Diagonal([sqrt(1/n_data)]) From 17c1069c6f8582fc4577170c11e1d75d70aa3c00 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 29 Oct 2025 08:31:44 -0400 Subject: [PATCH 15/28] fix test also capability in subsampling convergence test --- test/algorithms/klminwassfwdbwd.jl | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/test/algorithms/klminwassfwdbwd.jl b/test/algorithms/klminwassfwdbwd.jl index 8261a7dd5..67f493c7e 100644 --- a/test/algorithms/klminwassfwdbwd.jl +++ b/test/algorithms/klminwassfwdbwd.jl @@ -94,13 +94,13 @@ @testset "subsampling" begin n_data = 8 - modelstats = subsamplednormal(Random.default_rng(), n_data) - (; model, n_dims, μ_true, L_true) = modelstats + @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats - L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) - q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) - @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3) alg_sub = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-3, subsampling) @@ -114,6 +114,12 @@ seed = (0x38bef07cf9cc549d) rng = StableRNG(seed) + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + T = 10 batchsize = 3 subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) @@ -131,7 +137,13 @@ @test L == L_repl end - @testset "convergence" begin + @testset "convergence capability=$(capability)" for capability in [1,2] + modelstats = subsamplednormal(Random.default_rng(), n_data; capability) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + T = 1000 batchsize = 1 subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) From 283961f6bc296b9a8034d46af46bf7bfcda9f2ba Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 29 Oct 2025 08:32:35 -0400 Subject: [PATCH 16/28] update docs remove comment 
that hessian is required for fwdbwdwass --- docs/src/klminwassfwdbwd.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/klminwassfwdbwd.md b/docs/src/klminwassfwdbwd.md index a49808910..54198e196 100644 --- a/docs/src/klminwassfwdbwd.md +++ b/docs/src/klminwassfwdbwd.md @@ -5,7 +5,6 @@ This algorithm aims to minimize the exclusive (or reverse) Kullback-Leibler (KL) divergence by running proximal gradient descent (also known as forward-backward splitting) in Wasserstein space[^DBCS2023]. (This algorithm is also sometimes referred to as "Wasserstein VI".) Since `KLMinWassFwdBwd` is a measure-space algorithm, its use is restricted to full-rank Gaussian variational families (`FullRankGaussian`) that makes the measure-valued operations tractable. -Furthermore, it requires Hessians of the target log-density. ```@docs KLMinWassFwdBwd From 724103cf290bd56cbb418a1bd0c209ec54ae906d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 29 Oct 2025 08:39:02 -0400 Subject: [PATCH 17/28] update docs --- docs/src/klminwassfwdbwd.md | 7 +++---- src/algorithms/klminwassfwdbwd.jl | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/src/klminwassfwdbwd.md b/docs/src/klminwassfwdbwd.md index 54198e196..4c1911fff 100644 --- a/docs/src/klminwassfwdbwd.md +++ b/docs/src/klminwassfwdbwd.md @@ -69,10 +69,9 @@ where $$\widehat{\nabla_{\mathrm{BW}} \mathcal{U}}$$ is a stochastic estimate of and $$\mathrm{W}_2$$ is the Wasserstein-2 distance. When $$\mathcal{Q}$$ is set to be the Bures-Wasserstein space of $$\mathbb{R}^d$$, this algorithm is referred to as the Jordan-Kinderlehrer-Otto (JKO) scheme[^JKO1998], which was originally developed to study gradient flows under Wasserstein metrics. -Within this context, `KLMinWassFwdBwd` can be viewed as a numerical realization of the JKO scheme. -This also exactly corresponds to the measure-space analog of [KLMinRepGradProxDescent](@ref klminrepgradproxdescent). -Similarly to `KLMinRepGradProxDescent`, the JKO proximal operator to be tractable, which has to be derived for different variational families on a case-by-case basis. -Diao *et al.*[^DBCS2023] derived the JKO update for multivariate Gaussians, which is implemented by `KLMinWassFwdBwd`. +Within this context, `KLMinWassFwdBwd` can be viewed as a numerical realization of the JKO scheme by restricting $$\mathcal{Q}$$ to be a tractable parametric variational family. +Specifically, Diao *et al.*[^DBCS2023] derived the JKO update for multivariate Gaussians, which is implemented by `KLMinWassFwdBwd`. +`KLMinWassFwdBwd` also exactly corresponds to the measure-space analog of [KLMinRepGradProxDescent](@ref klminrepgradproxdescent). [^JKO1998]: Jordan, R., Kinderlehrer, D., & Otto, F. (1998). The variational formulation of the Fokker--Planck equation. *SIAM Journal on Mathematical Analysis*, 29(1). [^JGJS1999]: Jordan, M. I., Ghahramani, Z., Jaakkola, T. S., & Saul, L. K. (1999). An introduction to variational methods for graphical models. Machine learning, 37, 183-233. diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index 5960cf3f3..8c82f93c9 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -44,7 +44,7 @@ end """ gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) -Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. 
For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function internally uses the Bonnet-Price estimator. Otherwise, it uses Stein's identity.
+Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.

From f8ca2bec833475fc6e068e264c3806a5e848a843 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 29 Oct 2025 18:58:00 -0400
Subject: [PATCH 18/28] apply formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/algorithms/klminwassfwdbwd.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index 8c82f93c9..618dd98fd 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -60,7 +60,7 @@ function gaussian_expectation_gradient_and_hessian!(
     n_samples::Int,
     grad_buf::AbstractVector{T},
     hess_buf::AbstractMatrix{T},
-    prob
+    prob,
 ) where {T<:Real,L}
     logπ_avg = zero(T)
     fill!(grad_buf, zero(T))

From b478db986966526718d3bbba79dc8add52f8ee1a Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 29 Oct 2025 18:58:09 -0400
Subject: [PATCH 19/28] apply formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/algorithms/klminwassfwdbwd.jl | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index 618dd98fd..a53f543de 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -66,7 +66,8 @@
     fill!(grad_buf, zero(T))
     fill!(hess_buf, zero(T))
 
-    if LogDensityProblems.capabilities(typeof(prob)) ≤ LogDensityProblems.LogDensityOrder{1}()
+    if LogDensityProblems.capabilities(typeof(prob)) ≤
+        LogDensityProblems.LogDensityOrder{1}()
         # Use Stein's identity
         d = LogDensityProblems.dimension(prob)
         u = randn(rng, T, d, n_samples)

From 176726d13eef3417e73a32cd5ee5a61599c72258 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 29 Oct 2025 18:58:18 -0400
Subject: [PATCH 20/28] apply formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 src/algorithms/klminwassfwdbwd.jl | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index a53f543de..220a39a16 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -85,7 +85,9 @@
         z = rand(rng, q, n_samples)
         for b in 1:n_samples
             zb = view(z, :, b)
-            logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(prob, zb)
+            logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
+                prob, zb
+            )
             logπ_avg += logπ/n_samples
             grad_buf += ∇logπ/n_samples
             hess_buf += ∇2logπ/n_samples

From 53522c9eda33bda492999a47669964d4120fcb0f Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 29 Oct 2025 18:58:25 -0400
Subject: [PATCH 21/28] apply formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 test/algorithms/klminwassfwdbwd.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/algorithms/klminwassfwdbwd.jl b/test/algorithms/klminwassfwdbwd.jl
index 67f493c7e..a6875a750 100644
--- a/test/algorithms/klminwassfwdbwd.jl
+++ b/test/algorithms/klminwassfwdbwd.jl
@@ -137,7 +137,7 @@
         @test L == L_repl
     end
 
-    @testset "convergence capability=$(capability)" for capability in [1,2]
+    @testset "convergence capability=$(capability)" for capability in [1, 2]
         modelstats = subsamplednormal(Random.default_rng(), n_data; capability)
         (; model, n_dims, μ_true, L_true) = modelstats

From e21084cebb6a3f2a3e00dcb9d76380a8341cfe8c Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 30 Oct 2025 03:32:52 -0400
Subject: [PATCH 22/28] update docs

Co-authored-by: Markus Hauru
---
 docs/src/klminwassfwdbwd.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/docs/src/klminwassfwdbwd.md b/docs/src/klminwassfwdbwd.md
index 4c1911fff..4d36a6c21 100644
--- a/docs/src/klminwassfwdbwd.md
+++ b/docs/src/klminwassfwdbwd.md
@@ -23,6 +23,7 @@ estimate_objective(
 ```
 
 [^DBCS2023]: Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In *International Conference on Machine Learning*. PMLR.
+
 ## [Methodology](@id klminwassfwdbwd_method)
 
 This algorithm aims to solve the problem

From fd36d385991719a64d7d8fe16c9b35bc5b18d6f0 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 30 Oct 2025 03:33:17 -0400
Subject: [PATCH 23/28] update docs

Co-authored-by: Markus Hauru
---
 src/algorithms/klminwassfwdbwd.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index 220a39a16..5b7aed210 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -5,7 +5,7 @@
 KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023].
 
-Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. It the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity.
+Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity.
 
 # (Keyword) Arguments
 - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`)

From 5ef580578a429b94b2a1f1e8bdcd6b64ecbbba88 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 30 Oct 2025 03:35:16 -0400
Subject: [PATCH 24/28] run formatter

Co-authored-by: Markus Hauru
---
 src/algorithms/klminwassfwdbwd.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index 5b7aed210..04ec83c43 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -155,7 +155,7 @@ function step(
         rng, q, n_samples, grad_buf, hess_buf, prob_sub
     )
 
-    m′ = m - η*-grad_buf
+    m′ = m - η * (-grad_buf)
     M = I - η*Hermitian(-hess_buf)
     Σ_half = Hermitian(M*Σ*M)

From ff67c2e49a916d66830238449d8b7ff0dbea92f0 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 30 Oct 2025 03:36:42 -0400
Subject: [PATCH 25/28] refactor test for `KLMinWassFwdBwd`

---
 test/algorithms/klminwassfwdbwd.jl | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/test/algorithms/klminwassfwdbwd.jl b/test/algorithms/klminwassfwdbwd.jl
index a6875a750..3832017b7 100644
--- a/test/algorithms/klminwassfwdbwd.jl
+++ b/test/algorithms/klminwassfwdbwd.jl
@@ -43,17 +43,15 @@
         end
     end
 
-    begin
-        alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1.0)
+    @testset "error low capability" begin
+        modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0)
+        (; model, n_dims) = modelstats
 
-        @testset "error low capability" begin
-            modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0)
-            (; model, n_dims) = modelstats
+        alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1.0)
 
-            L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
-            q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
-            @test_throws "first-order" optimize(alg, 1, model, q0)
-        end
+        L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
+        q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
+        @test_throws "first-order" optimize(alg, 1, model, q0)
     end
 
     @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [

From 2bd3bc8ed78172b1ed707ab428b8094ead08ef2f Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 30 Oct 2025 03:38:08 -0400
Subject: [PATCH 26/28] run formatter

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 docs/src/klminwassfwdbwd.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/docs/src/klminwassfwdbwd.md b/docs/src/klminwassfwdbwd.md
index 4d36a6c21..4c1911fff 100644
--- a/docs/src/klminwassfwdbwd.md
+++ b/docs/src/klminwassfwdbwd.md
@@ -23,7 +23,6 @@ estimate_objective(
 ```
 
 [^DBCS2023]: Diao, M. Z., Balasubramanian, K., Chewi, S., & Salim, A. (2023). Forward-backward Gaussian variational inference via JKO in the Bures-Wasserstein space. In *International Conference on Machine Learning*. PMLR.
-
 ## [Methodology](@id klminwassfwdbwd_method)
 
 This algorithm aims to solve the problem

From 85dcaa9564cd2380d5dc7664cc61967d4cf25f59 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Thu, 30 Oct 2025 03:40:02 -0400
Subject: [PATCH 27/28] improve docstrings callback section

---
 src/algorithms/constructors.jl    | 16 ++++++++--------
 src/algorithms/klminwassfwdbwd.jl |  6 +++---
 2 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/algorithms/constructors.jl b/src/algorithms/constructors.jl
index 32c85b228..ed2203816 100644
--- a/src/algorithms/constructors.jl
+++ b/src/algorithms/constructors.jl
@@ -21,12 +21,12 @@ KL divergence minimization by running stochastic gradient descent with the repar
 # Output
 - `q_averaged`: The variational approximation formed by the averaged SGD iterates.
 
-# Callback
-The callback function `callback` has a signature of
+# Callback Signature
+The `callback` function supplied to `optimize` needs to have the following signature:
 
     callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
 
-The arguments are as follows:
+The keyword arguments are as follows:
 - `rng`: Random number generator internally used by the algorithm.
 - `iteration`: The index of the current iteration.
 - `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
@@ -100,12 +100,12 @@ Thus, only the entropy estimators with a "ZeroGradient" suffix are allowed.
 # Output
 - `q_averaged`: The variational approximation formed by the averaged SGD iterates.
 
-# Callback
-The callback function `callback` has a signature of
+# Callback Signature
+The `callback` function supplied to `optimize` needs to have the following signature:
 
     callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
 
-The arguments are as follows:
+The keyword arguments are as follows:
 - `rng`: Random number generator internally used by the algorithm.
 - `iteration`: The index of the current iteration.
 - `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
@@ -178,11 +178,11 @@ KL divergence minimization by running stochastic gradient descent with the score
 - `q_averaged`: The variational approximation formed by the averaged SGD iterates.
 
 # Callback
-The callback function `callback` has a signature of
+The `callback` function supplied to `optimize` needs to have the following signature:
 
     callback(; rng, iteration, restructure, params, averaged_params, restructure, gradient)
 
-The arguments are as follows:
+The keyword arguments are as follows:
 - `rng`: Random number generator internally used by the algorithm.
 - `iteration`: The index of the current iteration.
 - `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(params)` reconstructs the current variational approximation.
diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl
index 04ec83c43..f834b539a 100644
--- a/src/algorithms/klminwassfwdbwd.jl
+++ b/src/algorithms/klminwassfwdbwd.jl
@@ -18,12 +18,12 @@ Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variatio
 # Output
 - `q`: The last iterate of the algorithm.
 
-# Callback
-The callback function `callback` has a signature of
+# Callback Signature
+The `callback` function supplied to `optimize` needs to have the following signature:
 
     callback(; rng, iteration, q, info)
 
-The arguments are as follows:
+The keyword arguments are as follows:
 - `rng`: Random number generator internally used by the algorithm.
 - `iteration`: The index of the current iteration.
 - `q`: Current variational approximation.

From 4b3487598b0660e5a6d2a27a6e078c2d65cbc309 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Sat, 1 Nov 2025 04:30:32 -0400
Subject: [PATCH 28/28] bump patch version

---
 HISTORY.md   | 2 +-
 Project.toml | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/HISTORY.md b/HISTORY.md
index 067c6bb02..1c4e19251 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -1,4 +1,4 @@
-# Release 0.6
+# Release 0.5.1
 
 This update adds new variational inference algorithms in light of the flexibility added in the v0.5 update.
 Specifically, the following measure-space optimization algorithms have been added:
diff --git a/Project.toml b/Project.toml
index 5d31be04b..a7ff8a27a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.5.0"
+version = "0.5.1"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
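
Read together, the hunks above also make the forward half-step of `KLMinWassFwdBwd` explicit: writing η for the `stepsize` and using the expectations estimated by `gaussian_expectation_gradient_and_hessian!` (held in `grad_buf` and `hess_buf`), the update reformatted in [PATCH 24/28] corresponds to

```math
m_{k+1} = m_k + \eta \, \mathbb{E}_{q_k}\!\left[\nabla \log \pi\right], \qquad
M_k = I + \eta \, \mathbb{E}_{q_k}\!\left[\nabla^2 \log \pi\right], \qquad
\Sigma_{k+1/2} = M_k \, \Sigma_k \, M_k
```

The subsequent backward (proximal/JKO) step for the entropy term lives in the same `step` function but is not shown in this series.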
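
The docstrings revised in [PATCH 23/28] state that Stein's identity is used when the target only exposes first-order capability. Below is a minimal illustrative sketch of one common form of that estimator for a Gaussian `q = N(m, L*L')`; `stein_expected_hessian` and `gradlogπ` are hypothetical names used only for illustration, and the package's actual internal estimator may differ in detail.

```julia
using LinearAlgebra

# Estimate E_q[∇² log π] using only gradients, via Stein's identity:
#   E_q[∇² log π] = E_u[∇ log π(m + L*u) * u'] * inv(L),   u ~ N(0, I).
function stein_expected_hessian(gradlogπ, m::AbstractVector, L::LowerTriangular, n_samples::Int)
    d = length(m)
    H = zeros(eltype(m), d, d)
    for _ in 1:n_samples
        u = randn(d)               # auxiliary standard normal sample
        g = gradlogπ(m + L*u)      # ∇ log π evaluated at z = m + L*u ~ q
        H += (g*u')/n_samples      # Monte Carlo average of ∇ log π(z) * u'
    end
    H = H/L                        # right-multiply by inv(L)
    return Symmetric((H + H')/2)   # symmetrize, since the true expectation is symmetric
end

# Sanity check on a standard normal target, where E_q[∇² log π] = -I:
gradlogπ(z) = -z
Ĥ = stein_expected_hessian(gradlogπ, zeros(2), LowerTriangular(Matrix(1.0I, 2, 2)), 10_000)
```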
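
Finally, a minimal end-to-end sketch of how the pieces touched by this series fit together. Only `KLMinWassFwdBwd`, `FullRankGaussian`, the positional call `optimize(alg, max_iter, model, q0)`, and the `LogDensityProblems` functions are taken from the patches themselves; the toy target type, the step size, and the `callback` keyword passed to `optimize` are assumptions made for illustration, with the callback following the keyword arguments listed in the revised docstrings.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems

# Hypothetical toy target with second-order capability, so KLMinWassFwdBwd can use
# the sample average of the Hessian instead of Stein's identity.
struct IsotropicGaussianTarget
    μ::Vector{Float64}
end

LogDensityProblems.dimension(p::IsotropicGaussianTarget) = length(p.μ)
function LogDensityProblems.capabilities(::Type{IsotropicGaussianTarget})
    return LogDensityProblems.LogDensityOrder{2}()
end
LogDensityProblems.logdensity(p::IsotropicGaussianTarget, z) = -sum(abs2, z - p.μ)/2
function LogDensityProblems.logdensity_gradient_and_hessian(p::IsotropicGaussianTarget, z)
    d = length(p.μ)
    return LogDensityProblems.logdensity(p, z), p.μ - z, -Matrix{Float64}(I, d, d)
end

n_dims = 2
model = IsotropicGaussianTarget(ones(n_dims))

# Initial full-rank Gaussian approximation, mirroring the test setup in the patches.
L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)

alg = KLMinWassFwdBwd(; n_samples=10, stepsize=1e-2)

# Callback following the documented signature `callback(; rng, iteration, q, info)`;
# extra keywords are absorbed defensively in case the exact set differs.
function cb(; iteration, q, kwargs...)
    if iteration % 50 == 0
        @info "iteration $iteration"
    end
    return nothing
end

# Assumes `optimize` accepts a `callback` keyword, as the revised docstrings suggest.
result = optimize(alg, 200, model, q0; callback=cb)
```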