diff --git a/Project.toml b/Project.toml
index f1a0e255a..6c83cebb5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "AdvancedVI"
 uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
-version = "0.3.2"
+version = "0.4.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
diff --git a/bench/Project.toml b/bench/Project.toml
index 6511dcc20..703c546bd 100644
--- a/bench/Project.toml
+++ b/bench/Project.toml
@@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 ADTypes = "1"
-AdvancedVI = "0.3"
+AdvancedVI = "0.3, 0.4"
 BenchmarkTools = "1"
 Bijectors = "0.13, 0.14, 0.15"
 Distributions = "0.25.111"
diff --git a/docs/Project.toml b/docs/Project.toml
index eb49ea71e..4466b49b7 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -15,7 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
 
 [compat]
 ADTypes = "1"
-AdvancedVI = "0.3, 0.2"
+AdvancedVI = "0.4"
 Bijectors = "0.13.6, 0.14, 0.15"
 Distributions = "0.25"
 Documenter = "1"
diff --git a/docs/src/optimization.md b/docs/src/optimization.md
index 422c9db86..957cb38df 100644
--- a/docs/src/optimization.md
+++ b/docs/src/optimization.md
@@ -33,11 +33,20 @@ For this, an operator acting on the parameters can be supplied via the `operato
 
 ### [`ClipScale`](@id clipscale)
 
-For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
+For the location-scale family, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
 To ensure this, we provide the following projection operator:
 
 ```@docs
 ClipScale
 ```
 
+### [`ProximalLocationScaleEntropy`](@id proximalocationscaleentropy)
+
+ELBO maximization with the location-scale family tends to be unstable when the scale has small eigenvalues or the stepsize is large.
+To remedy this, a proximal operator of the entropy[^D2020] can be used.
+
+```@docs
+ProximalLocationScaleEntropy
+```
+
 [^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
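For context, a minimal end-to-end sketch of the new proximal workflow, assembled from the tests added later in this diff. The log-density `model` and its dimensionality `n_dims` are placeholders; everything else is API introduced or touched by this change:

```julia
using AdvancedVI, LinearAlgebra, ADTypes, ForwardDiff, Random

# `model` is assumed to implement the LogDensityProblems interface.
n_dims = 5
q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

# The proximal operator handles the entropy term, so the ELBO estimator must
# use one of the zero-gradient entropy estimators introduced in this change.
objective = RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient())

q_avg, _, stats, _ = optimize(
    Random.default_rng(),
    model,
    objective,
    q0,
    1_000;
    optimizer=DoWG(1.0),
    averager=PolynomialAveraging(),
    operator=ProximalLocationScaleEntropy(),
    adtype=AutoForwardDiff(),
)
```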
diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl
index 8f6205167..67818e94e 100644
--- a/ext/AdvancedVIBijectorsExt.jl
+++ b/ext/AdvancedVIBijectorsExt.jl
@@ -9,6 +9,7 @@ using Random
 function AdvancedVI.apply(
     op::ClipScale,
     ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
+    state,
     params,
     restructure,
 )
@@ -27,6 +28,7 @@ end
 function AdvancedVI.apply(
     op::ClipScale,
     ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
+    state,
     params,
     restructure,
 )
@@ -40,6 +42,26 @@ function AdvancedVI.apply(
     return params
 end
 
+function AdvancedVI.apply(
+    ::AdvancedVI.ProximalLocationScaleEntropy,
+    ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
+    leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S},
+    params,
+    restructure,
+) where {S}
+    q = restructure(params)
+
+    stepsize = AdvancedVI.stepsize_from_optimizer_state(leaf.rule, leaf.state)
+    diag_idx = diagind(q.dist.scale)
+    scale_diag = q.dist.scale[diag_idx]
+    @. q.dist.scale[diag_idx] =
+        scale_diag + 1 / 2 * (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag)
+
+    params, _ = Optimisers.destructure(q)
+
+    return params
+end
+
 function AdvancedVI.reparam_with_entropy(
     rng::Random.AbstractRNG,
     q::Bijectors.TransformedDistribution,
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 18135cde7..85bbf2baf 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -177,13 +177,14 @@ function estimate_gradient! end
 abstract type AbstractEntropyEstimator end
 
 """
-    estimate_entropy(entropy_estimator, mc_samples, q)
+    estimate_entropy(entropy_estimator, mc_samples, q, q_stop)
 
 Estimate the entropy of `q`.
 
 # Arguments
 - `entropy_estimator`: Entropy estimation strategy.
 - `q`: Variational approximation.
+- `q_stop`: Variational approximation detached from the automatic differentiation graph.
 - `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)
 
 # Returns
@@ -192,7 +193,12 @@ Estimate the entropy of `q`.
 function estimate_entropy end
 
 export RepGradELBO,
-    ScoreGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy
+    ScoreGradELBO,
+    ClosedFormEntropy,
+    StickingTheLandingEntropy,
+    MonteCarloEntropy,
+    ClosedFormEntropyZeroGradient,
+    StickingTheLandingEntropyZeroGradient
 
 include("objectives/elbo/entropy.jl")
 include("objectives/elbo/repgradelbo.jl")
@@ -259,20 +265,21 @@ export NoAveraging, PolynomialAveraging
 abstract type AbstractOperator end
 
 """
-    apply(op::AbstractOperator, family, params, restructure)
+    apply(op::AbstractOperator, family, opt_state, params, restructure)
 
 Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator.
 
 # Arguments
 - `op::AbstractOperator`: Operator operating on the parameters `params`.
 - `family::Type`: Type of the variational approximation `restructure(params)`.
+- `opt_state`: State of the optimizer.
 - `params`: Variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from `params`.
 
 # Returns
 - `oped_params`: Parameters resulting from applying the operator.
 """
-function apply(::AbstractOperator, ::Type, ::Any, ::Any) end
+function apply(::AbstractOperator, ::Type, ::Any, ::Any, ::Any) end
 
 """
     IdentityOperator()
@@ -281,11 +288,12 @@ Identity operator.
 """
 struct IdentityOperator <: AbstractOperator end
 
-apply(::IdentityOperator, ::Type, params, restructure) = params
+apply(::IdentityOperator, ::Type, opt_st, params, restructure) = params
 
 include("optimization/clip_scale.jl")
+include("optimization/proximal_location_scale_entropy.jl")
 
-export IdentityOperator, ClipScale
+export IdentityOperator, ClipScale, ProximalLocationScaleEntropy
 
 # Main optimization routine
 function optimize end
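Downstream code that defines custom operators must adopt the widened signature. A hypothetical no-op operator illustrating it (illustrative only, not part of this change):

```julia
using AdvancedVI

# Hypothetical operator: inspects the optimizer state now threaded through
# `apply` and returns the parameters unchanged.
struct InspectStateOperator <: AdvancedVI.AbstractOperator end

function AdvancedVI.apply(::InspectStateOperator, ::Type, opt_state, params, restructure)
    @info "operator invoked" state_type = typeof(opt_state)
    return params
end
```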
+""" +struct ClosedFormEntropyZeroGradient <: AbstractEntropyEstimator end + +function estimate_entropy(::ClosedFormEntropyZeroGradient, ::Any, ::Any, q_stop) + return entropy(q_stop) +end + """ ClosedFormEntropy() @@ -9,12 +24,27 @@ Use closed-form expression of entropy[^TL2014][^KTRGB2017]. """ struct ClosedFormEntropy <: AbstractEntropyEstimator end -maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q - -function estimate_entropy(::ClosedFormEntropy, ::Any, q) +function estimate_entropy(::ClosedFormEntropy, ::Any, q, q_stop) return entropy(q) end +""" + MonteCarloEntropy() + +Monte Carlo estimation of the entropy. + +# Requirements +- The variational approximation `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. +""" +struct MonteCarloEntropy <: AbstractEntropyEstimator end + +function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, q_stop) + return mean(eachcol(mc_samples)) do mc_sample + -logpdf(q, mc_sample) + end +end + """ StickingTheLandingEntropy() @@ -26,14 +56,35 @@ The "sticking the landing" entropy estimator[^RWD2017]. """ struct StickingTheLandingEntropy <: AbstractEntropyEstimator end -struct MonteCarloEntropy <: AbstractEntropyEstimator end +function estimate_entropy( + ::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q, q_stop +) + return mean(eachcol(mc_samples)) do mc_sample + -logpdf(q_stop, mc_sample) + end +end -maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop +""" + StickingTheLandingEntropyZeroGradient() + +The "sticking the landing" entropy estimator[^RWD2017] but modified to have a gradient of mean zero. +This is expected to be used only with `ProximalLocationScaleEntropy`. + +# Requirements +- The variational approximation `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. +- The variational approximation implements `entropy`. 
+""" +struct StickingTheLandingEntropyZeroGradient <: AbstractEntropyEstimator end function estimate_entropy( - ::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q + ::Union{MonteCarloEntropy,StickingTheLandingEntropyZeroGradient}, + mc_samples::AbstractMatrix, + q, + q_stop, ) - mean(eachcol(mc_samples)) do mc_sample - -logpdf(q, mc_sample) + entropy_stl = mean(eachcol(mc_samples)) do mc_sample + -logpdf(q_stop, mc_sample) end + return entropy_stl - entropy(q) + entropy(q_stop) end diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl index dc6b772c1..c6aa9135c 100644 --- a/src/objectives/elbo/repgradelbo.jl +++ b/src/objectives/elbo/repgradelbo.jl @@ -67,13 +67,6 @@ function Base.show(io::IO, obj::RepGradELBO) return print(io, ")") end -function estimate_entropy_maybe_stl( - entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop -) - q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) - return estimate_entropy(entropy_estimator, samples, q_maybe_stop) -end - function estimate_energy_with_samples(prob, samples) return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) end @@ -98,7 +91,7 @@ function reparam_with_entropy( rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator ) samples = rand(rng, q, n_samples) - entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + entropy = estimate_entropy(ent_est, samples, q, q_stop) return samples, entropy end diff --git a/src/optimization/clip_scale.jl b/src/optimization/clip_scale.jl index 68aac072a..5a4c52f60 100644 --- a/src/optimization/clip_scale.jl +++ b/src/optimization/clip_scale.jl @@ -9,11 +9,11 @@ Optimisers.@def struct ClipScale <: AbstractOperator epsilon = 1e-5 end -function apply(::ClipScale, family::Type, params, restructure) +function apply(::ClipScale, family::Type, state, params, restructure) return error("`ClipScale` is not defined for the variational family of type $(family).") end -function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) +function apply(op::ClipScale, ::Type{<:MvLocationScale}, state, params, restructure) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) @@ -26,7 +26,7 @@ function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure) return params end -function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure) +function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, state, params, restructure) q = restructure(params) ϵ = convert(eltype(params), op.epsilon) diff --git a/src/optimization/proximal_location_scale_entropy.jl b/src/optimization/proximal_location_scale_entropy.jl new file mode 100644 index 000000000..33946527b --- /dev/null +++ b/src/optimization/proximal_location_scale_entropy.jl @@ -0,0 +1,61 @@ + +""" + ProximalLocationScaleEntropy() + +Proximal operator for the entropy of a location-scale distribution, which is defined as +```math + \\mathrm{prox}(\\lambda) = \\argmin_{\\lambda^{\\prime}} - \\mathbb{H}(q_{\\lambda^{\\prime}}) + \\frac{1}{2 \\gamma_t} \\left\\lVert \\lambda - \\lambda^{\\prime} \\right\\rVert , +``` +where \$\\gamma_t\$ is the stepsize the optimizer used with the proximal operator. 
diff --git a/src/optimize.jl b/src/optimize.jl
index 5ceb948ec..1be3181ae 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -92,7 +92,7 @@ function optimize(
         grad = DiffResults.gradient(grad_buf)
         opt_st, params = Optimisers.update!(opt_st, params, grad)
-        params = apply(operator, typeof(q_init), params, restructure)
+        params = apply(operator, typeof(q_init), opt_st, params, restructure)
         avg_st = apply(averager, avg_st, params)
 
         if !isnothing(callback)
diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
index 0b9782fa5..eb367b696 100644
--- a/test/inference/repgradelbo_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -1,3 +1,4 @@
+
 AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme"
     Dict(
         :Enzyme => AutoEnzyme(;
diff --git a/test/inference/repgradelbo_proximal_locationscale.jl b/test/inference/repgradelbo_proximal_locationscale.jl
new file mode 100644
index 000000000..620cceca7
--- /dev/null
+++ b/test/inference/repgradelbo_proximal_locationscale.jl
@@ -0,0 +1,113 @@
+
+AD_repgradelbo_locationscale = if TEST_GROUP == "Enzyme"
+    Dict(
+        :Enzyme => AutoEnzyme(;
+            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+            function_annotation=Enzyme.Const,
+        ),
+    )
+else
+    Dict(
+        :ForwardDiff => AutoForwardDiff(),
+        :ReverseDiff => AutoReverseDiff(),
+        :Zygote => AutoZygote(),
+        :Mooncake => AutoMooncake(; config=Mooncake.Config()),
+    )
+end
+
+@testset "inference RepGradELBO Proximal VILocationScale" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
+                                                                     [Float64, Float32],
+        (modelname, modelconstr) in
+        Dict(:NormalMeanField => normal_meanfield, :NormalFullRank => normal_fullrank),
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy =>
+                RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()),
+            :RepGradELBOStickingTheLanding =>
+                RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()),
+        ),
+        (adbackname, adtype) in AD_repgradelbo_locationscale
+
+        seed = (0x38bef07cf9cc549d)
+        rng = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
+
+        T = 1000
+        η = 1e-3
+        opt = DoWG(1.0)
+
+        # For small enough η, the error of SGD, Δλ, is bounded as
+        #     Δλ ≤ ρ^T Δλ0 + O(η),
+        # where ρ = 1 - ημ, μ is the strong convexity constant.
+        contraction_rate = 1 - η * strong_convexity
+
+        q0 = if is_meanfield
+            MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
+        else
+            L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims))
+            FullRankGaussian(zeros(realtype, n_dims), L0)
+        end
+
+        @testset "convergence" begin
+            Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
+            q_avg, _, stats, _ = optimize(
+                rng,
+                model,
+                objective,
+                q0,
+                T;
+                optimizer=opt,
+                averager=PolynomialAveraging(),
+                operator=ProximalLocationScaleEntropy(),
+                show_progress=PROGRESS,
+                adtype=adtype,
+            )
+
+            μ = q_avg.location
+            L = q_avg.scale
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ contraction_rate^(T / 2) * Δλ0
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q_avg, _, stats, _ = optimize(
+                rng,
+                model,
+                objective,
+                q0,
+                T;
+                optimizer=opt,
+                averager=PolynomialAveraging(),
+                operator=ProximalLocationScaleEntropy(),
+                show_progress=PROGRESS,
+                adtype=adtype,
+            )
+            μ = q_avg.location
+            L = q_avg.scale
+
+            rng_repl = StableRNG(seed)
+            q_avg, _, stats, _ = optimize(
+                rng_repl,
+                model,
+                objective,
+                q0,
+                T;
+                optimizer=opt,
+                averager=PolynomialAveraging(),
+                operator=ProximalLocationScaleEntropy(),
+                show_progress=PROGRESS,
+                adtype=adtype,
+            )
+            μ_repl = q_avg.location
+            L_repl = q_avg.scale
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
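To make the convergence tolerance above concrete (arithmetic only; assumes `strong_convexity == 1`, as for the standard normal targets used here):

```julia
# contraction_rate with η = 1e-3 and μ = 1
ρ = 1 - 1e-3 * 1
ρ^(1000 / 2)  # ≈ 0.6064: Δλ must fall below ~61% of Δλ0 after T = 1000 steps
```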
diff --git a/test/inference/repgradelbo_proximal_locationscale_bijectors.jl b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl
new file mode 100644
index 000000000..169113571
--- /dev/null
+++ b/test/inference/repgradelbo_proximal_locationscale_bijectors.jl
@@ -0,0 +1,118 @@
+AD_repgradelbo_locationscale_bijectors = if TEST_GROUP == "Enzyme"
+    Dict(
+        :Enzyme => AutoEnzyme(;
+            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+            function_annotation=Enzyme.Const,
+        ),
+    )
+else
+    Dict(
+        :ForwardDiff => AutoForwardDiff(),
+        :ReverseDiff => AutoReverseDiff(),
+        :Zygote => AutoZygote(),
+        :Mooncake => AutoMooncake(; config=Mooncake.Config()),
+    )
+end
+
+@testset "inference RepGradELBO Proximal VILocationScale Bijectors" begin
+    @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
+                                                                     [Float64, Float32],
+        (modelname, modelconstr) in
+        Dict(:NormalLogNormalMeanField => normallognormal_meanfield),
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy =>
+                RepGradELBO(10; entropy=ClosedFormEntropyZeroGradient()),
+            :RepGradELBOStickingTheLanding =>
+                RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()),
+        ),
+        (adbackname, adtype) in AD_repgradelbo_locationscale_bijectors
+
+        seed = (0x38bef07cf9cc549d)
+        rng = StableRNG(seed)
+
+        modelstats = modelconstr(rng, realtype)
+        (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats
+
+        T = 1000
+        η = 1e-3
+        opt = DoWG(1.0)
+
+        b = Bijectors.bijector(model)
+        b⁻¹ = inverse(b)
+        μ0 = Zeros(realtype, n_dims)
+        L0 = Diagonal(Ones(realtype, n_dims))
+
+        q0_η = if is_meanfield
+            MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
+        else
+            L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims))
+            FullRankGaussian(zeros(realtype, n_dims), L0)
+        end
+        q0_z = Bijectors.transformed(q0_η, b⁻¹)
+
+        # For small enough η, the error of SGD, Δλ, is bounded as
+        #     Δλ ≤ ρ^T Δλ0 + O(η),
+        # where ρ = 1 - ημ, μ is the strong convexity constant.
+        contraction_rate = 1 - η * strong_convexity
+
+        @testset "convergence" begin
+            Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+            q_avg, _, stats, _ = optimize(
+                rng,
+                model,
+                objective,
+                q0_z,
+                T;
+                optimizer=opt,
+                averager=PolynomialAveraging(),
+                operator=ProximalLocationScaleEntropy(),
+                show_progress=PROGRESS,
+                adtype=adtype,
+            )
+
+            μ = q_avg.dist.location
+            L = q_avg.dist.scale
+            Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
+
+            @test Δλ ≤ contraction_rate^(T / 2) * Δλ0
+            @test eltype(μ) == eltype(μ_true)
+            @test eltype(L) == eltype(L_true)
+        end
+
+        @testset "determinism" begin
+            rng = StableRNG(seed)
+            q_avg, _, stats, _ = optimize(
+                rng,
+                model,
+                objective,
+                q0_z,
+                T;
+                optimizer=opt,
+                averager=PolynomialAveraging(),
+                operator=ProximalLocationScaleEntropy(),
+                show_progress=PROGRESS,
+                adtype=adtype,
+            )
+            μ = q_avg.dist.location
+            L = q_avg.dist.scale
+
+            rng_repl = StableRNG(seed)
+            q_avg, _, stats, _ = optimize(
+                rng_repl,
+                model,
+                objective,
+                q0_z,
+                T;
+                optimizer=opt,
+                averager=PolynomialAveraging(),
+                operator=ProximalLocationScaleEntropy(),
+                show_progress=PROGRESS,
+                adtype=adtype,
+            )
+            μ_repl = q_avg.dist.location
+            L_repl = q_avg.dist.scale
+            @test μ == μ_repl
+            @test L == L_repl
+        end
+    end
+end
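Condensed from the testset above, the bijector path looks like this outside the test harness (a sketch; `model` is a placeholder for a log-density problem with a `Bijectors.bijector` method, such as the one built by `normallognormal_meanfield` in the test suite):

```julia
using AdvancedVI, Bijectors, LinearAlgebra, ADTypes, ReverseDiff, Random

n_dims = 2  # placeholder dimensionality

b⁻¹ = inverse(Bijectors.bijector(model))  # `model` assumed in scope
q0_η = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
q0_z = Bijectors.transformed(q0_η, b⁻¹)

q_avg, _, _, _ = optimize(
    Random.default_rng(),
    model,
    RepGradELBO(10; entropy=StickingTheLandingEntropyZeroGradient()),
    q0_z,
    1_000;
    optimizer=DoWG(1.0),
    averager=PolynomialAveraging(),
    operator=ProximalLocationScaleEntropy(),
    adtype=AutoReverseDiff(),
)
q_avg.dist.location  # variational mean in the unconstrained space
```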
diff --git a/test/interface/clip_scale.jl b/test/interface/clip_scale.jl
index d9a6330ce..e2679f5fa 100644
--- a/test/interface/clip_scale.jl
+++ b/test/interface/clip_scale.jl
@@ -23,7 +23,10 @@ end
 
         params, re = Optimisers.destructure(q)
 
-        params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), params, re)
+        opt_st = AdvancedVI.maybe_init_optimizer(
+            NamedTuple(), Optimisers.Descent(1e-2), params
+        )
+        params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), opt_st, params, re)
         q′ = re(params′)
 
         if isnothing(bijector)
@@ -54,7 +57,10 @@ end
 
         params, re = Optimisers.destructure(q)
 
-        params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), params, re)
+        opt_st = AdvancedVI.maybe_init_optimizer(
+            NamedTuple(), Optimisers.Descent(1e-2), params
+        )
+        params′ = AdvancedVI.apply(ClipScale(ϵ), typeof(q), opt_st, params, re)
         q′ = re(params′)
 
         if isnothing(bijector)
diff --git a/test/interface/proximal_location_scale_entropy.jl b/test/interface/proximal_location_scale_entropy.jl
new file mode 100644
index 000000000..2f53ec858
--- /dev/null
+++ b/test/interface/proximal_location_scale_entropy.jl
@@ -0,0 +1,62 @@
+
+@testset "interface ProximalLocationScaleEntropy" begin
+    @testset "MvLocationScale" begin
+        @testset "$(string(covtype)) $(realtype) $(bijector)" for covtype in
+                                                                  [:meanfield, :fullrank],
+            realtype in [Float32, Float64],
+            bijector in [nothing, :identity]
+
+            stepsize = 1e-2
+            optimizer = Descent(stepsize)
+
+            d = 5
+            μ = zeros(realtype, d)
+            ϵ = sqrt(realtype(0.5))
+            L = if covtype == :fullrank
+                LowerTriangular(Matrix{realtype}(I, d, d))
+            elseif covtype == :meanfield
+                Diagonal(ones(realtype, d))
+            end
+            q = if covtype == :fullrank
+                FullRankGaussian(μ, L)
+            elseif covtype == :meanfield
+                MeanFieldGaussian(μ, L)
+            end
+            q = if isnothing(bijector)
+                q
+            else
+                Bijectors.TransformedDistribution(q, identity)
+            end
+
+            # The proximal operator for the entropy of a location-scale distribution
+            # solves the subproblem:
+            #
+            #     argmin_{L} - logabsdet(L) + 1/(2η) sum(abs2, L - L′)
+            #
+            # for some fixed L′ with respect to L over the set of triangular matrices
+            # that have strictly positive eigenvalues.
+            #
+            # The solution L to this convex program is the solution to
+            #
+            #     ∇logabsdet(L) = ∇ 1/(2η) sum(abs2, L - L′) .
+            #
+            # This unit test will check that this equation is satisfied.
+
+            params, re = Optimisers.destructure(q)
+            opt_st = AdvancedVI.maybe_init_optimizer(NamedTuple(), optimizer, params)
+            params′ = AdvancedVI.apply(
+                ProximalLocationScaleEntropy(), typeof(q), opt_st, params, re
+            )
+
+            q′ = re(params′)
+            scale′ = isnothing(bijector) ? q′.scale : q′.dist.scale
+
+            grad_left = only(Zygote.gradient(L_ -> first(logabsdet(L_)), scale′))
+            grad_right = only(
+                Zygote.gradient(L_ -> sum(abs2, L_ - L) / (2 * stepsize), scale′)
+            )
+
+            @test grad_left ≈ grad_right
+        end
+    end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 2130d0e7b..495af19c1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -50,6 +50,7 @@ if TEST_GROUP == "All" || TEST_GROUP == "Interface"
     include("interface/averaging.jl")
     include("interface/scoregradelbo.jl")
     include("interface/clip_scale.jl")
+    include("interface/proximal_location_scale_entropy.jl")
 end
 
 if TEST_GROUP == "All" || TEST_GROUP == "Interface" || TEST_GROUP == "Enzyme"
@@ -69,6 +70,8 @@ if TEST_GROUP == "All" || TEST_GROUP == "Inference" || TEST_GROUP == "Enzyme"
     include("inference/repgradelbo_distributionsad.jl")
     include("inference/repgradelbo_locationscale.jl")
     include("inference/repgradelbo_locationscale_bijectors.jl")
+    include("inference/repgradelbo_proximal_locationscale.jl")
+    include("inference/repgradelbo_proximal_locationscale_bijectors.jl")
    include("inference/scoregradelbo_distributionsad.jl")
     include("inference/scoregradelbo_locationscale.jl")
     include("inference/scoregradelbo_locationscale_bijectors.jl")
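For reference, the optimality condition the interface test checks, in display form (restating the comment in `test/interface/proximal_location_scale_entropy.jl`): the proximal step returns the triangular ``L`` with positive diagonal satisfying

```math
\nabla_L \log \lvert \det L \rvert = \frac{1}{\eta} \, (L - L^{\prime}) ,
```

which is precisely the equality of the two Zygote gradients asserted at the end of that test.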