From 51e575062ee45d78e5ea905016e3afacee7fd749 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:44:29 -0400 Subject: [PATCH 01/21] add proximal operator for the entropy of location-scale families --- ext/AdvancedVIBijectorsExt.jl | 22 ++++++ src/AdvancedVI.jl | 19 +++-- src/objectives/elbo/entropy.jl | 71 ++++++++++++++++--- src/objectives/elbo/repgradelbo.jl | 9 +-- src/optimization/clip_scale.jl | 18 ++++- src/optimize.jl | 2 +- test/inference/repgradelbo_distributionsad.jl | 1 + test/interface/clip_scale.jl | 10 ++- test/runtests.jl | 5 +- 9 files changed, 128 insertions(+), 29 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 8f6205167..7433a190a 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -9,6 +9,7 @@ using Random function AdvancedVI.apply( op::ClipScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, + state, params, restructure, ) @@ -27,6 +28,7 @@ end function AdvancedVI.apply( op::ClipScale, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}}, + state, params, restructure, ) @@ -40,6 +42,26 @@ function AdvancedVI.apply( return params end +function AdvancedVI.apply( + ::ProximalLocationScaleEntropy, + ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, + leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S}, + params, + restructure, +) where {S} + q = restructure(params) + + stepsize = AdvancedVI.stepsize_from_optimizer_state(leaf.rule, leaf.state) + diag_idx = diagind(q.dist.scale) + scale_diag = q.dist.scale[diag_idx] + @. q.dist.scale[diag_idx] = + scale_diag + 1 / 2 * (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) + + params, _ = Optimisers.destructure(q) + + return params +end + function AdvancedVI.reparam_with_entropy( rng::Random.AbstractRNG, q::Bijectors.TransformedDistribution, diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 18135cde7..16373e1d3 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -177,13 +177,14 @@ function estimate_gradient! end abstract type AbstractEntropyEstimator end """ - estimate_entropy(entropy_estimator, mc_samples, q) + estimate_entropy(entropy_estimator, mc_samples, q, q_stop) Estimate the entropy of `q`. # Arguments - `entropy_estimator`: Entropy estimation strategy. - `q`: Variational approximation. +- `q_stop`: Variational approximation with detached from the automatic differentiation graph. - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) # Returns @@ -192,7 +193,12 @@ Estimate the entropy of `q`. function estimate_entropy end export RepGradELBO, - ScoreGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy + ScoreGradELBO, + ClosedFormEntropy, + StickingTheLandingEntropy, + MonteCarloEntropy, + ClosedFormEntropyZeroGradient, + StickingTheLandingEntropyZeroGradient include("objectives/elbo/entropy.jl") include("objectives/elbo/repgradelbo.jl") @@ -259,7 +265,7 @@ export NoAveraging, PolynomialAveraging abstract type AbstractOperator end """ - apply(op::AbstractOperator, family, params, restructure) + apply(op::AbstractOperator, family, rule, state, params, restructure) Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator. @@ -272,7 +278,7 @@ Apply operator `op` on the variational parameters `params`. For instance, `op` c # Returns - `oped_params`: Parameters resulting from applying the operator. """ -function apply(::AbstractOperator, ::Type, ::Any, ::Any) end +function apply(::AbstractOperator, ::Type, ::Optimisers.AbstractRule, ::Any, ::Any, ::Any) end """ IdentityOperator() @@ -281,11 +287,12 @@ Identity operator. """ struct IdentityOperator <: AbstractOperator end -apply(::IdentityOperator, ::Type, params, restructure) = params +apply(::IdentityOperator, ::Type, opt_st, params, restructure) = params include("optimization/clip_scale.jl") +include("optimization/proximal_location_scale_entropy.jl") -export IdentityOperator, ClipScale +export IdentityOperator, ClipScale, ProximalLocationScaleEntropy # Main optimization routine function optimize end diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 210b49ca9..b929af9d4 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -1,4 +1,18 @@ +""" + ClosedFormEntropyZeroGradient() + +Use closed-form expression of entropy but detach it from the AD graph. + +# Requirements +- The variational approximation implements `entropy`. +""" +struct ClosedFormEntropyZeroGradient <: AbstractEntropyEstimator end + +function estimate_entropy(::ClosedFormEntropyZeroGradient, ::Any, ::Any, q_stop) + return entropy(q_stop) +end + """ ClosedFormEntropy() @@ -9,12 +23,32 @@ Use closed-form expression of entropy[^TL2014][^KTRGB2017]. """ struct ClosedFormEntropy <: AbstractEntropyEstimator end -maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q - -function estimate_entropy(::ClosedFormEntropy, ::Any, q) +function estimate_entropy(::ClosedFormEntropy, ::Any, q, q_stop) return entropy(q) end +""" + MonteCarloEntropy() + +Monte Carlo estimation of the entropy. + +# Requirements +- The variational approximation `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. +""" +struct MonteCarloEntropy <: AbstractEntropyEstimator end + +function estimate_entropy( + ::MonteCarloEntropy, + mc_samples::AbstractMatrix, + q, + q_stop, +) + return mean(eachcol(mc_samples)) do mc_sample + -logpdf(q, mc_sample) + end +end + """ StickingTheLandingEntropy() @@ -26,14 +60,35 @@ The "sticking the landing" entropy estimator[^RWD2017]. """ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -struct MonteCarloEntropy <: AbstractEntropyEstimator end +function estimate_entropy( + ::StickingTheLandingEntropy, + mc_samples::AbstractMatrix, + q, + q_stop, +) + return mean(eachcol(mc_samples)) do mc_sample + -logpdf(q_stop, mc_sample) + end +end -maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop +""" + StickingTheLandingEntropyZeroGradient() + +# Requirements +- The variational approximation `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. +- The variational approximation implements `entropy`. +""" +struct StickingTheLandingEntropyZeroGradient <: AbstractEntropyEstimator end function estimate_entropy( - ::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q + ::Union{MonteCarloEntropy,StickingTheLandingEntropyZeroGradient}, + mc_samples::AbstractMatrix, + q, + q_stop, ) - mean(eachcol(mc_samples)) do mc_sample - -logpdf(q, mc_sample) + entropy_stl = mean(eachcol(mc_samples)) do mc_sample + -logpdf(q_stop, mc_sample) end + return entropy_stl - entropy(q) + entropy(q_stop) end diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index dc6b772c1..c6aa9135c 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -67,13 +67,6 @@ function Base.show(io::IO, obj::RepGradELBO) return print(io, ")") end -function estimate_entropy_maybe_stl( - entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop -) - q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - return estimate_entropy(entropy_estimator, samples, q_maybe_stop) -end - function estimate_energy_with_samples(prob, samples) return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end @@ -98,7 +91,7 @@ function reparam_with_entropy( rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) - entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + entropy = estimate_entropy(ent_est, samples, q, q_stop) return samples, entropy end diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index 68aac072a..b2dd8491a 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -9,11 +9,17 @@ Optimisers.@def struct ClipScale <: AbstractOperator epsilon = 1e-5 end -function apply(::ClipScale, family::Type, params, restructure) +function apply(::ClipScale, family::Type, state, params, restructure) return error("`ClipScale` is not defined for the variational family of type $(family).") end -function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) +function apply( + op::ClipScale, + ::Type{<:MvLocationScale}, + state, + params, + restructure, +) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) @@ -26,7 +32,13 @@ function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) return params end -function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure) +function apply( + op::ClipScale, + ::Type{<:MvLocationScaleLowRank}, + state, + params, + restructure, +) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) diff --git a/src/optimize.jl b/src/optimize.jl index 5ceb948ec..1be3181ae 100644 --- a/src/optimize.jl +++ b/src/optimize.jl @@ -92,7 +92,7 @@ function optimize( grad = DiffResults.gradient(grad_buf) opt_st, params = Optimisers.update!(opt_st, params, grad) - params = apply(operator, typeof(q_init), params, restructure) + params = apply(operator, typeof(q_init), opt_st, params, restructure) avg_st = apply(averager, avg_st, params) if !isnothing(callback) diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 0b9782fa5..eb367b696 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -1,3 +1,4 @@ + AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme" Dict( :Enzyme => AutoEnzyme(; diff --git a/test/interface/clip_scale.jl b/test/interface/clip_scale.jl index d9a6330ce..e2679f5fa 100644 --- a/test/interface/clip_scale.jl +++ b/test/interface/clip_scale.jl @@ -23,7 +23,10 @@ end params, re = Optimisers.destructure(q) - params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), params, re) + opt_st = AdvancedVI.maybe_init_optimizer( + NamedTuple(), Optimisers.Descent(1e-2), params + ) + params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), opt_st, params, re) q′ = re(params′) if isnothing(bijector) @@ -54,7 +57,10 @@ end params, re = Optimisers.destructure(q) - params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), params, re) + opt_st = AdvancedVI.maybe_init_optimizer( + NamedTuple(), Optimisers.Descent(1e-2), params + ) + params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), opt_st, params, re) q′ = re(params′) if isnothing(bijector) diff --git a/test/runtests.jl b/test/runtests.jl index 2130d0e7b..36af42bb8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,12 +44,13 @@ include("models/normal.jl") include("models/normallognormal.jl") if TEST_GROUP == "All" || TEST_GROUP == "Interface" - # Interface tests that do not involve testing on Enzyme + #Interface tests that do not involve testing on Enzyme include("interface/optimize.jl") include("interface/rules.jl") include("interface/averaging.jl") include("interface/scoregradelbo.jl") include("interface/clip_scale.jl") + include("interface/proximal_location_scale_entropy.jl") end if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme" @@ -69,6 +70,8 @@ if TEST_GROUP == "All" || TEST_GROUP == "Inference" || TEST_GROUP == "Enzyme" include("inference/repgradelbo_distributionsad.jl") include("inference/repgradelbo_locationscale.jl") include("inference/repgradelbo_locationscale_bijectors.jl") + include("inference/repgradelbo_proximal_locationscale.jl") + include("inference/repgradelbo_proximal_locationscale_bijectors.jl") include("inference/scoregradelbo_distributionsad.jl") include("inference/scoregradelbo_locationscale.jl") include("inference/scoregradelbo_locationscale_bijectors.jl") From 79ee37eda735aca6f80bbf160b9a6233988197ba Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:54:50 -0400 Subject: [PATCH 02/21] add missing files --- .../repgradelbo_proximal_locationscale.jl | 113 +++++++++++++++++ ...adelbo_proximal_locationscale_bijectors.jl | 118 ++++++++++++++++++ .../proximal_location_scale_entropy.jl | 60 +++++++++ 3 files changed, 291 insertions(+) create mode 100644 test/inference/repgradelbo_proximal_locationscale.jl create mode 100644 test/inference/repgradelbo_proximal_locationscale_bijectors.jl create mode 100644 test/interface/proximal_location_scale_entropy.jl diff --git a/test/inference/repgradelbo_proximal_locationscale.jl b/test/inference/repgradelbo_proximal_locationscale.jl new file mode 100644 index 000000000..e6fc25971 --- /dev/null +++ b/test/inference/repgradelbo_proximal_locationscale.jl @@ -0,0 +1,113 @@ + +AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end + +@testset "inference RepGradELBO Proximal VILocationScale" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:Normal => normal_meanfield, :Normal => normal_fullrank), + (objname, objective) in Dict( + :RepGradELBOClosedFormEntropy => + RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()), + :RepGradELBOStickingTheLanding => + RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()), + ), + (adbackname, adtype) in AD_repgradelbo_locationscale + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + η = 1e-3 + opt = DoWG(1.) + + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η * strong_convexity + + q0 = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + FullRankGaussian(zeros(realtype, n_dims), L0) + end + + @testset "convergence" begin + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0, + T; + optimizer=opt, + averager=PolynomialAveraging(), + operator=ProximalLocationScaleEntropy(), + show_progress=PROGRESS, + adtype=adtype, + ) + + μ = q_avg.location + L = q_avg.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0, + T; + optimizer=opt, + averager=PolynomialAveraging(), + operator=ProximalLocationScaleEntropy(), + show_progress=PROGRESS, + adtype=adtype, + ) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng_repl, + model, + objective, + q0, + T; + optimizer=opt, + averager=PolynomialAveraging(), + operator=ProximalLocationScaleEntropy(), + show_progress=PROGRESS, + adtype=adtype, + ) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end +end diff --git a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl new file mode 100644 index 000000000..f8fb4f6ac --- /dev/null +++ b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl @@ -0,0 +1,118 @@ +AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme" + Dict( + :Enzyme => AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const, + ), + ) +else + Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Mooncake => AutoMooncake(; config=Mooncake.Config()), + ) +end + +@testset "inference RepGradELBO Proximal VILocationScale Bijectors" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in + [Float64, Float32], + (modelname, modelconstr) in + Dict(:NormalLogNormalMeanField => normallognormal_meanfield), + (objname, objective) in Dict( + :RepGradELBOClosedFormEntropy => + RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()), + :RepGradELBOStickingTheLanding => + RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()), + ), + (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + η = 1e-3 + opt = DoWG(1.) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) + + q0_η = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + FullRankGaussian(zeros(realtype, n_dims), L0) + end + q0_z = Bijectors.transformed(q0_η, b⁻¹) + + # For small enough η, the error of SGD, Δλ, is bounded as + # Δλ ≤ ρ^T Δλ0 + O(η), + # where ρ = 1 - ημ, μ is the strong convexity constant. + contraction_rate = 1 - η * strong_convexity + + @testset "convergence" begin + Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + averager=PolynomialAveraging(), + operator=ProximalLocationScaleEntropy(), + show_progress=PROGRESS, + adtype=adtype, + ) + + μ = q_avg.dist.location + L = q_avg.dist.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ contraction_rate^(T / 2) * Δλ0 + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng, + model, + objective, + q0_z, + T; + optimizer=opt, + averager=PolynomialAveraging(), + operator=ProximalLocationScaleEntropy(), + show_progress=PROGRESS, + adtype=adtype, + ) + μ = q_avg.dist.location + L = q_avg.dist.scale + + rng_repl = StableRNG(seed) + q_avg, _, stats, _ = optimize( + rng_repl, + model, + objective, + q0_z, + T; + optimizer=opt, + averager=PolynomialAveraging(), + operator=ProximalLocationScaleEntropy(), + show_progress=PROGRESS, + adtype=adtype, + ) + μ_repl = q_avg.dist.location + L_repl = q_avg.dist.scale + @test μ == μ_repl + @test L == L_repl + end + end +end diff --git a/test/interface/proximal_location_scale_entropy.jl b/test/interface/proximal_location_scale_entropy.jl new file mode 100644 index 000000000..e0a7dee79 --- /dev/null +++ b/test/interface/proximal_location_scale_entropy.jl @@ -0,0 +1,60 @@ + +@testset "interface ProximalLocationScaleEntropy" begin + @testset "MvLocationScale" begin + @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in + [:meanfield, :fullrank], + realtype in [Float32, Float64], + bijector in [nothing, :identity] + + stepsize = 1e-2 + optimizer = Descent(stepsize) + + d = 5 + μ = zeros(realtype, d) + ϵ = sqrt(realtype(0.5)) + L = if covtype == :fullrank + LowerTriangular(Matrix{realtype}(I, d, d)) + elseif covtype == :meanfield + Diagonal(ones(realtype, d)) + end + q = if covtype == :fullrank + FullRankGaussian(μ, L) + elseif covtype == :meanfield + MeanFieldGaussian(μ, L) + end + q = if isnothing(bijector) + q + else + Bijectors.TransformedDistribution(q, identity) + end + + # The proximal operator for the entropy of a location scale distribution + # solves the subproblem: + # + # argmin_{L} - logabsdet(L) + 1/(2η) norm(ab2, L - L') + # + # for some fixed L' with respect to L over the set of triangular matrices + # that have strictly positive eigenvalues. + # + # The solution L to this convex program is the solution to + # + # ∇logabsdet(L) = ∇ 1/(2η) norm(abs2, L - L') . + # + # This unit test will check that this equation is satisfied. + + params, re = Optimisers.destructure(q) + opt_st = AdvancedVI.maybe_init_optimizer(NamedTuple(), optimizer, params) + params′ = AdvancedVI.apply( + ProximalLocationScaleEntropy(), typeof(q), opt_st, params, re + ) + + q′ = re(params′) + scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale + + grad_left = Zygote.gradient(L_ -> logabsdet(L_) |> first, scale′) |> only + grad_right = Zygote.gradient(L_ -> sum(abs2, L_ - L)/(2*stepsize), scale′) |> only + + @test grad_left ≈ grad_right + end + end +end From 7d8ea1f7c8e3d5740516ac8e4b796021f02fd61f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:55:31 -0400 Subject: [PATCH 03/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/optimization/clip_scale.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index b2dd8491a..f1ed161e2 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -32,13 +32,7 @@ function apply( return params end -function apply( - op::ClipScale, - ::Type{<:MvLocationScaleLowRank}, - state, - params, - restructure, -) +function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, state, params, restructure) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) From f7a407995aff6e0ea189e1e22215a443cb914551 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:55:37 -0400 Subject: [PATCH 04/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/objectives/elbo/entropy.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index b929af9d4..64c689fd7 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -38,12 +38,7 @@ Monte Carlo estimation of the entropy. """ struct MonteCarloEntropy <: AbstractEntropyEstimator end -function estimate_entropy( - ::MonteCarloEntropy, - mc_samples::AbstractMatrix, - q, - q_stop, -) +function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, q_stop) return mean(eachcol(mc_samples)) do mc_sample -logpdf(q, mc_sample) end From 9d2a6949c90338e90d54c79c39ff6ebc3492aebe Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:55:42 -0400 Subject: [PATCH 05/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/optimization/clip_scale.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index f1ed161e2..5a4c52f60 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -13,13 +13,7 @@ function apply(::ClipScale, family::Type, state, params, restructure) return error("`ClipScale` is not defined for the variational family of type $(family).") end -function apply( - op::ClipScale, - ::Type{<:MvLocationScale}, - state, - params, - restructure, -) +function apply(op::ClipScale, ::Type{<:MvLocationScale}, state, params, restructure) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) From d960dc92151433d699d4a81a0163cf7630f07864 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:55:48 -0400 Subject: [PATCH 06/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/objectives/elbo/entropy.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 64c689fd7..17b3972cd 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -56,10 +56,7 @@ The "sticking the landing" entropy estimator[^RWD2017]. struct StickingTheLandingEntropy <: AbstractEntropyEstimator end function estimate_entropy( - ::StickingTheLandingEntropy, - mc_samples::AbstractMatrix, - q, - q_stop, + ::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q, q_stop ) return mean(eachcol(mc_samples)) do mc_sample -logpdf(q_stop, mc_sample) From 11baeb92971bee20ddc252679686ac1cbf94200f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:58:14 -0400 Subject: [PATCH 07/21] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/inference/repgradelbo_proximal_locationscale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/repgradelbo_proximal_locationscale.jl b/test/inference/repgradelbo_proximal_locationscale.jl index e6fc25971..620cceca7 100644 --- a/test/inference/repgradelbo_proximal_locationscale.jl +++ b/test/inference/repgradelbo_proximal_locationscale.jl @@ -36,7 +36,7 @@ end T = 1000 η = 1e-3 - opt = DoWG(1.) + opt = DoWG(1.0) # For small enough η, the error of SGD, Δλ, is bounded as # Δλ ≤ ρ^T Δλ0 + O(η), From ea4f86cadd806abf3f94a3ffe2984ce8d3ee862e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:58:24 -0400 Subject: [PATCH 08/21] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/inference/repgradelbo_proximal_locationscale_bijectors.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl index f8fb4f6ac..169113571 100644 --- a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl @@ -35,7 +35,7 @@ end T = 1000 η = 1e-3 - opt = DoWG(1.) + opt = DoWG(1.0) b = Bijectors.bijector(model) b⁻¹ = inverse(b) From 09b81d711b9c6f36857de6c1eac7b17624c504a0 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:58:29 -0400 Subject: [PATCH 09/21] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/interface/proximal_location_scale_entropy.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/interface/proximal_location_scale_entropy.jl b/test/interface/proximal_location_scale_entropy.jl index e0a7dee79..118a04bd8 100644 --- a/test/interface/proximal_location_scale_entropy.jl +++ b/test/interface/proximal_location_scale_entropy.jl @@ -51,8 +51,9 @@ q′ = re(params′) scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale - grad_left = Zygote.gradient(L_ -> logabsdet(L_) |> first, scale′) |> only - grad_right = Zygote.gradient(L_ -> sum(abs2, L_ - L)/(2*stepsize), scale′) |> only + grad_left = only(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) + grad_right = + only(Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′)) @test grad_left ≈ grad_right end From 317986fa3808ab2e965e942362cf2889aff3864f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 16:59:23 -0400 Subject: [PATCH 10/21] increment version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index af8b16047..f1a0e255a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.3.1" +version = "0.3.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From fb9ec58a4ed9704feef7de70b280f037be39200b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 17:00:05 -0400 Subject: [PATCH 11/21] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/interface/proximal_location_scale_entropy.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/interface/proximal_location_scale_entropy.jl b/test/interface/proximal_location_scale_entropy.jl index 118a04bd8..2f53ec858 100644 --- a/test/interface/proximal_location_scale_entropy.jl +++ b/test/interface/proximal_location_scale_entropy.jl @@ -52,8 +52,9 @@ scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale grad_left = only(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′)) - grad_right = - only(Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′)) + grad_right = only( + Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′) + ) @test grad_left ≈ grad_right end From 010e1f406edb195479e7930f9718c7f46356b9ad Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 17:02:21 -0400 Subject: [PATCH 12/21] improve docstring for zero gradient entropy estimators --- src/objectives/elbo/entropy.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl index 17b3972cd..376904c39 100644 --- a/src/objectives/elbo/entropy.jl +++ b/src/objectives/elbo/entropy.jl @@ -3,6 +3,7 @@ ClosedFormEntropyZeroGradient() Use closed-form expression of entropy but detach it from the AD graph. +This is expected to be used only with `ProximalLocationScaleEntropy`. # Requirements - The variational approximation implements `entropy`. @@ -66,6 +67,9 @@ end """ StickingTheLandingEntropyZeroGradient() +The "sticking the landing" entropy estimator[^RWD2017] but modified to have a gradient of mean zero. +This is expected to be used only with `ProximalLocationScaleEntropy`. + # Requirements - The variational approximation `q` implements `logpdf`. - `logpdf(q, η)` must be differentiable by the selected AD framework. From 3f187e8115029147381c67a0eeadffe0e36c002c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 17:17:26 -0400 Subject: [PATCH 13/21] add missing file --- .../proximal_location_scale_entropy.jl | 62 +++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 src/optimization/proximal_location_scale_entropy.jl diff --git a/src/optimization/proximal_location_scale_entropy.jl b/src/optimization/proximal_location_scale_entropy.jl new file mode 100644 index 000000000..648e45186 --- /dev/null +++ b/src/optimization/proximal_location_scale_entropy.jl @@ -0,0 +1,62 @@ + +""" + ProximalLocationScaleEntropy() + +Proximal operator for the entropy of a location-scale distribution, which is defined as +```math + \\mathrm{prox}(\\lambda) = \\argmin_{\\lambda^{\\prime}} - \\mathbb{H}(q_{\\lambda^{\\prime}}) + \\frac{1}{2 \\gamma_t} \\left\\lVert \\lambda - \\lambda^{\\prime} \\right\\rVert , +``` +where \$\\gamma_t\$ is the stepsize the optimizer used with the proximal operator. +This assumes the variational family is `<:VILocationScale` and the optimizer is one of the following: +- `DoG` +- `DoWG` +- `Descent` + +For ELBO maximization, since this proximal operator handles the entropy, the gradient estimator for the ELBO must ingore the entropy term. +That is, the `entropy` keyword argument of `RepGradELBO` muse be one of the following: +- `ClosedFormEntropyZeroGradient` +- `StickingTheLandingEntropyZeroGradient` +""" +struct ProximalLocationScaleEntropy <: AbstractOperator end + +function apply(::ProximalLocationScaleEntropy, family, state, params, restructure) + return error("`ProximalLocationScaleEntropy` only supports `<:MvLocationScale`.") +end + +function stepsize_from_optimizer_state(rule::Optimisers.AbstractRule, state) + return error( + "`ProximalLocationScaleEntropy` does not support optimization rule $(typeof(rule))." + ) +end + +stepsize_from_optimizer_state(rule::Descent, ::Any) = rule.eta + +function stepsize_from_optimizer_state(::DoG, state) + _, v, r = state + return r / sqrt(v) +end + +function stepsize_from_optimizer_state(::DoWG, state) + _, v, r = state + return r * r / sqrt(v) +end + +function apply( + ::ProximalLocationScaleEntropy, + ::Type{<:MvLocationScale}, + leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S}, + params, + restructure, +) where {S} + q = restructure(params) + + stepsize = stepsize_from_optimizer_state(leaf.rule, leaf.state) + diag_idx = diagind(q.scale) + scale_diag = q.scale[diag_idx] + @. q.scale[diag_idx] = + scale_diag + 1 / 2 * (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) + + params, _ = Optimisers.destructure(q) + + return params +end From 7f017127a10b924e89e0347044915382abafaac6 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 17:17:30 -0400 Subject: [PATCH 14/21] add documentation for proximal operator --- docs/src/elbo/repgradelbo.md | 1 + docs/src/optimization.md | 11 ++++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md index ccc69a97f..1e67667de 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/elbo/repgradelbo.md @@ -266,6 +266,7 @@ Furthermore, in a lot of cases, a low-accuracy solution may be sufficient. [^RWD2017]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. [^KMG2024]: Kim, K., Ma, Y., & Gardner, J. (2024). Linear Convergence of Black-Box Variational Inference: Should We Stick the Landing?. In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR. + ## Advanced Usage There are two major ways to customize the behavior of `RepGradELBO` diff --git a/docs/src/optimization.md b/docs/src/optimization.md index 422c9db86..957cb38df 100644 --- a/docs/src/optimization.md +++ b/docs/src/optimization.md @@ -33,11 +33,20 @@ For this, an operator acting on the parameters can be supplied via the `operato ### [`ClipScale`](@id clipscale) -For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. +For the location-scale family, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020]. To ensure this, we provide the following projection operator: ```@docs ClipScale ``` +### [`ProximalLocationScaleEntropy`](@id proximalocationscaleentropy) + +ELBO maximization with the location-scale family tends to be unstable when the scale has small eigenvalues or the stepsize is large. +To remedy this, a proximal operator of the entropy[^D2020] can be used. + +```@docs +ProximalLocationScaleEntropy +``` + [^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*. From 780b8506223a3478629d68d7d0d447c83d0814de Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 17:43:58 -0400 Subject: [PATCH 15/21] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/elbo/repgradelbo.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/elbo/repgradelbo.md b/docs/src/elbo/repgradelbo.md index 1e67667de..ccc69a97f 100644 --- a/docs/src/elbo/repgradelbo.md +++ b/docs/src/elbo/repgradelbo.md @@ -266,7 +266,6 @@ Furthermore, in a lot of cases, a low-accuracy solution may be sufficient. [^RWD2017]: Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. [^KMG2024]: Kim, K., Ma, Y., & Gardner, J. (2024). Linear Convergence of Black-Box Variational Inference: Should We Stick the Landing?. In International Conference on Artificial Intelligence and Statistics (pp. 235-243). PMLR. - ## Advanced Usage There are two major ways to customize the behavior of `RepGradELBO` From ebb07f0e7de85df5b884d0e1f2fe7793154dc665 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:07:51 -0400 Subject: [PATCH 16/21] fix improve type stability --- ext/AdvancedVIBijectorsExt.jl | 2 +- src/optimization/proximal_location_scale_entropy.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl index 7433a190a..67818e94e 100644 --- a/ext/AdvancedVIBijectorsExt.jl +++ b/ext/AdvancedVIBijectorsExt.jl @@ -43,7 +43,7 @@ function AdvancedVI.apply( end function AdvancedVI.apply( - ::ProximalLocationScaleEntropy, + ::AdvancedVI.ProximalLocationScaleEntropy, ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}}, leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S}, params, diff --git a/src/optimization/proximal_location_scale_entropy.jl b/src/optimization/proximal_location_scale_entropy.jl index 648e45186..b5b58632b 100644 --- a/src/optimization/proximal_location_scale_entropy.jl +++ b/src/optimization/proximal_location_scale_entropy.jl @@ -54,7 +54,7 @@ function apply( diag_idx = diagind(q.scale) scale_diag = q.scale[diag_idx] @. q.scale[diag_idx] = - scale_diag + 1 / 2 * (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) + scale_diag + (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) / 2 params, _ = Optimisers.destructure(q) From d70a7207d6367135fda0587dc49c0b78f345a996 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:17:00 -0400 Subject: [PATCH 17/21] apply formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/optimization/proximal_location_scale_entropy.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/optimization/proximal_location_scale_entropy.jl b/src/optimization/proximal_location_scale_entropy.jl index b5b58632b..c478ef600 100644 --- a/src/optimization/proximal_location_scale_entropy.jl +++ b/src/optimization/proximal_location_scale_entropy.jl @@ -53,8 +53,7 @@ function apply( stepsize = stepsize_from_optimizer_state(leaf.rule, leaf.state) diag_idx = diagind(q.scale) scale_diag = q.scale[diag_idx] - @. q.scale[diag_idx] = - scale_diag + (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) / 2 + @. q.scale[diag_idx] = scale_diag + (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) / 2 params, _ = Optimisers.destructure(q) From 188541b71323114565130039770f81b288a215bb Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 28 Mar 2025 14:41:47 -0400 Subject: [PATCH 18/21] fix typo in doctring Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- src/optimization/proximal_location_scale_entropy.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimization/proximal_location_scale_entropy.jl b/src/optimization/proximal_location_scale_entropy.jl index c478ef600..33946527b 100644 --- a/src/optimization/proximal_location_scale_entropy.jl +++ b/src/optimization/proximal_location_scale_entropy.jl @@ -12,7 +12,7 @@ This assumes the variational family is `<:VILocationScale` and the optimizer is - `DoWG` - `Descent` -For ELBO maximization, since this proximal operator handles the entropy, the gradient estimator for the ELBO must ingore the entropy term. +For ELBO maximization, since this proximal operator handles the entropy, the gradient estimator for the ELBO must ignore the entropy term. That is, the `entropy` keyword argument of `RepGradELBO` muse be one of the following: - `ClosedFormEntropyZeroGradient` - `StickingTheLandingEntropyZeroGradient` From cc2091138f0482ef440bfc548ba39b091e9a8486 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 28 Mar 2025 14:42:17 -0400 Subject: [PATCH 19/21] fix typo in comment Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 36af42bb8..495af19c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -44,7 +44,7 @@ include("models/normal.jl") include("models/normallognormal.jl") if TEST_GROUP == "All" || TEST_GROUP == "Interface" - #Interface tests that do not involve testing on Enzyme + # Interface tests that do not involve testing on Enzyme include("interface/optimize.jl") include("interface/rules.jl") include("interface/averaging.jl") From 88d3665398cc2a774774d7afa92b74fe8bcc8b97 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 28 Mar 2025 14:42:38 -0400 Subject: [PATCH 20/21] apply code review comments --- Project.toml | 2 +- src/AdvancedVI.jl | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index f1a0e255a..6c83cebb5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.3.2" +version = "0.4.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 16373e1d3..85bbf2baf 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -265,13 +265,14 @@ export NoAveraging, PolynomialAveraging abstract type AbstractOperator end """ - apply(op::AbstractOperator, family, rule, state, params, restructure) + apply(op::AbstractOperator, family, rule, opt_state, params, restructure) Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator. # Arguments - `op::AbstractOperator`: Operator operating on the parameters `params`. - `family::Type`: Type of the variational approximation `restructure(params)`. +- `opt_state`: State of the optimizer. - `params`: Variational parameters. - `restructure`: Function that reconstructs the variational approximation from `params`. From d1eb2d3f8a9a48e50dc45967c5952c68487fb585 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 28 Mar 2025 14:44:47 -0400 Subject: [PATCH 21/21] bump compat bound for subprojects --- bench/Project.toml | 2 +- docs/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/bench/Project.toml b/bench/Project.toml index 6511dcc20..703c546bd 100644 --- a/bench/Project.toml +++ b/bench/Project.toml @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] ADTypes = "1" -AdvancedVI = "0.3" +AdvancedVI = "0.3, 0.4" BenchmarkTools = "1" Bijectors = "0.13, 0.14, 0.15" Distributions = "0.25.111" diff --git a/docs/Project.toml b/docs/Project.toml index eb49ea71e..4466b49b7 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -15,7 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] ADTypes = "1" -AdvancedVI = "0.3, 0.2" +AdvancedVI = "0.4" Bijectors = "0.13.6, 0.14, 0.15" Distributions = "0.25" Documenter = "1"