Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
51e5750
add proximal operator for the entropy of location-scale families
Red-Portal Mar 14, 2025
79ee37e
add missing files
Red-Portal Mar 14, 2025
7d8ea1f
apply formatter
Red-Portal Mar 14, 2025
f7a4079
apply formatter
Red-Portal Mar 14, 2025
9d2a694
apply formatter
Red-Portal Mar 14, 2025
d960dc9
apply formatter
Red-Portal Mar 14, 2025
11baeb9
run formatter
Red-Portal Mar 14, 2025
ea4f86c
run formatter
Red-Portal Mar 14, 2025
09b81d7
run formatter
Red-Portal Mar 14, 2025
317986f
increment version
Red-Portal Mar 14, 2025
fb9ec58
fix formatting
Red-Portal Mar 14, 2025
010e1f4
improve docstring for zero gradient entropy estimators
Red-Portal Mar 14, 2025
8b147cd
Merge branch 'proximal_entropy_location_scale' of github.com:TuringLa…
Red-Portal Mar 14, 2025
3f187e8
add missing file
Red-Portal Mar 14, 2025
7f01712
add documentation for proximal operator
Red-Portal Mar 14, 2025
780b850
run formatter
Red-Portal Mar 14, 2025
ebb07f0
fix improve type stability
Red-Portal Mar 14, 2025
cff3416
Merge branch 'proximal_entropy_location_scale' of github.com:TuringLa…
Red-Portal Mar 14, 2025
d70a720
apply formatter
Red-Portal Mar 14, 2025
ce33a00
Merge branch 'main' of github.com:TuringLang/AdvancedVI.jl into proxi…
Red-Portal Mar 28, 2025
3c5cb5a
Merge branch 'proximal_entropy_location_scale' of github.com:TuringLa…
Red-Portal Mar 28, 2025
188541b
fix typo in doctring
Red-Portal Mar 28, 2025
cc20911
fix typo in comment
Red-Portal Mar 28, 2025
88d3665
apply code review comments
Red-Portal Mar 28, 2025
d1eb2d3
bump compat bound for subprojects
Red-Portal Mar 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.3.2"
version = "0.4.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion bench/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
AdvancedVI = "0.3"
AdvancedVI = "0.3, 0.4"
BenchmarkTools = "1"
Bijectors = "0.13, 0.14, 0.15"
Distributions = "0.25.111"
Expand Down
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
ADTypes = "1"
AdvancedVI = "0.3, 0.2"
AdvancedVI = "0.4"
Bijectors = "0.13.6, 0.14, 0.15"
Distributions = "0.25"
Documenter = "1"
Expand Down
11 changes: 10 additions & 1 deletion docs/src/optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,20 @@ For this, an operator acting on the parameters can be supplied via the `operato

### [`ClipScale`](@id clipscale)

For the location scale, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
For the location-scale family, it is often the case that optimization is stable only when the smallest eigenvalue of the scale matrix is strictly positive[^D2020].
To ensure this, we provide the following projection operator:

```@docs
ClipScale
```

### [`ProximalLocationScaleEntropy`](@id proximallocationscaleentropy)

ELBO maximization with the location-scale family tends to be unstable when the scale has small eigenvalues or the stepsize is large.
To remedy this, a proximal operator of the entropy[^D2020] can be used.

```@docs
ProximalLocationScaleEntropy
```

[^D2020]: Domke, J. (2020). Provable smoothness guarantees for black-box variational inference. In *International Conference on Machine Learning*.
22 changes: 22 additions & 0 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using Random
function AdvancedVI.apply(
op::ClipScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
state,
params,
restructure,
)
Expand All @@ -27,6 +28,7 @@ end
function AdvancedVI.apply(
op::ClipScale,
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScaleLowRank}},
state,
params,
restructure,
)
Expand All @@ -40,6 +42,26 @@ function AdvancedVI.apply(
return params
end

# Proximal-entropy update for Bijectors-wrapped (transformed) location-scale
# families: applies the closed-form entropy proximal step to the diagonal of
# the underlying distribution's scale matrix and returns the re-flattened
# parameters.
function AdvancedVI.apply(
    ::AdvancedVI.ProximalLocationScaleEntropy,
    ::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
    leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S},
    params,
    restructure,
) where {S}
    q = restructure(params)

    stepsize = AdvancedVI.stepsize_from_optimizer_state(leaf.rule, leaf.state)
    diag_idx = diagind(q.dist.scale)
    scale_diag = q.dist.scale[diag_idx]
    # σ ← σ + (√(σ² + 4γ) − σ)/2 is the closed-form entropy proximal step.
    # Use `... / 2` (integer divisor) instead of `1 / 2 * ...` to match the
    # in-package `MvLocationScale` method and to avoid promoting `Float32`
    # intermediates through the `Float64` literal `0.5`.
    @. q.dist.scale[diag_idx] =
        scale_diag + (sqrt(scale_diag^2 + 4 * stepsize) - scale_diag) / 2

    params, _ = Optimisers.destructure(q)

    return params
end

function AdvancedVI.reparam_with_entropy(
rng::Random.AbstractRNG,
q::Bijectors.TransformedDistribution,
Expand Down
20 changes: 14 additions & 6 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,14 @@ function estimate_gradient! end
abstract type AbstractEntropyEstimator end

"""
estimate_entropy(entropy_estimator, mc_samples, q)
estimate_entropy(entropy_estimator, mc_samples, q, q_stop)

Estimate the entropy of `q`.

# Arguments
- `entropy_estimator`: Entropy estimation strategy.
- `q`: Variational approximation.
- `q_stop`: Variational approximation detached from the automatic differentiation graph.
- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.)

# Returns
Expand All @@ -192,7 +193,12 @@ Estimate the entropy of `q`.
function estimate_entropy end

export RepGradELBO,
ScoreGradELBO, ClosedFormEntropy, StickingTheLandingEntropy, MonteCarloEntropy
ScoreGradELBO,
ClosedFormEntropy,
StickingTheLandingEntropy,
MonteCarloEntropy,
ClosedFormEntropyZeroGradient,
StickingTheLandingEntropyZeroGradient

include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")
Expand Down Expand Up @@ -259,20 +265,21 @@ export NoAveraging, PolynomialAveraging
abstract type AbstractOperator end

"""
apply(op::AbstractOperator, family, params, restructure)
    apply(op::AbstractOperator, family, opt_state, params, restructure)

Apply operator `op` on the variational parameters `params`. For instance, `op` could be a projection or proximal operator.

# Arguments
- `op::AbstractOperator`: Operator operating on the parameters `params`.
- `family::Type`: Type of the variational approximation `restructure(params)`.
- `opt_state`: State of the optimizer.
- `params`: Variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `params`.

# Returns
- `oped_params`: Parameters resulting from applying the operator.
"""
function apply(::AbstractOperator, ::Type, ::Any, ::Any) end
# Generic no-op fallback so `apply` exists for any operator/family pair.
# NOTE(review): the PR version took six positional arguments
# (op, family, ::Optimisers.AbstractRule, state, params, restructure), but
# every concrete method and the call in `optimize` pass five
# (op, family, opt_state, params, restructure), leaving the fallback
# unreachable. Align the arity with the call sites.
function apply(::AbstractOperator, ::Type, ::Any, ::Any, ::Any) end

"""
IdentityOperator()
Expand All @@ -281,11 +288,12 @@ Identity operator.
"""
struct IdentityOperator <: AbstractOperator end

apply(::IdentityOperator, ::Type, params, restructure) = params
# The identity operator returns the parameters unchanged.
function apply(::IdentityOperator, ::Type, opt_st, params, restructure)
    return params
end

include("optimization/clip_scale.jl")
include("optimization/proximal_location_scale_entropy.jl")

export IdentityOperator, ClipScale
export IdentityOperator, ClipScale, ProximalLocationScaleEntropy

# Main optimization routine
function optimize end
Expand Down
67 changes: 59 additions & 8 deletions src/objectives/elbo/entropy.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,19 @@

"""
    ClosedFormEntropyZeroGradient()

Evaluate the closed-form entropy on the detached variational approximation
`q_stop`, so the resulting estimate carries no gradient through the AD graph.
This is expected to be used only with `ProximalLocationScaleEntropy`.

# Requirements
- The variational approximation implements `entropy`.
"""
struct ClosedFormEntropyZeroGradient <: AbstractEntropyEstimator end

estimate_entropy(::ClosedFormEntropyZeroGradient, ::Any, ::Any, q_stop) = entropy(q_stop)

"""
ClosedFormEntropy()

Expand All @@ -9,12 +24,27 @@ Use closed-form expression of entropy[^TL2014][^KTRGB2017].
"""
struct ClosedFormEntropy <: AbstractEntropyEstimator end

maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q

function estimate_entropy(::ClosedFormEntropy, ::Any, q)
function estimate_entropy(::ClosedFormEntropy, ::Any, q, q_stop)
return entropy(q)
end

"""
    MonteCarloEntropy()

Monte Carlo estimation of the entropy: the negative log-density of `q`
averaged over the supplied samples.

# Requirements
- The variational approximation `q` implements `logpdf`.
- `logpdf(q, η)` must be differentiable by the selected AD framework.
"""
struct MonteCarloEntropy <: AbstractEntropyEstimator end

function estimate_entropy(::MonteCarloEntropy, mc_samples::AbstractMatrix, q, q_stop)
    # Each column of `mc_samples` is one Monte Carlo sample.
    return mean(sample -> -logpdf(q, sample), eachcol(mc_samples))
end

"""
StickingTheLandingEntropy()

Expand All @@ -26,14 +56,35 @@ The "sticking the landing" entropy estimator[^RWD2017].
"""
struct StickingTheLandingEntropy <: AbstractEntropyEstimator end

struct MonteCarloEntropy <: AbstractEntropyEstimator end
# "Sticking the landing": evaluate the log-density on the detached copy
# `q_stop`, which removes the entropy's score term from the gradient.
function estimate_entropy(
    ::StickingTheLandingEntropy, mc_samples::AbstractMatrix, q, q_stop
)
    return mean(sample -> -logpdf(q_stop, sample), eachcol(mc_samples))
end

maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop
"""
    StickingTheLandingEntropyZeroGradient()

The "sticking the landing" entropy estimator[^RWD2017], modified so that its gradient has mean zero.
This is expected to be used only with `ProximalLocationScaleEntropy`, which handles the entropy term itself.

# Requirements
- The variational approximation `q` implements `logpdf`.
- `logpdf(q, η)` must be differentiable by the selected AD framework.
- The variational approximation implements `entropy`.
"""
struct StickingTheLandingEntropyZeroGradient <: AbstractEntropyEstimator end

function estimate_entropy(
::Union{MonteCarloEntropy,StickingTheLandingEntropy}, mc_samples::AbstractMatrix, q
::Union{MonteCarloEntropy,StickingTheLandingEntropyZeroGradient},
mc_samples::AbstractMatrix,
q,
q_stop,
)
mean(eachcol(mc_samples)) do mc_sample
-logpdf(q, mc_sample)
entropy_stl = mean(eachcol(mc_samples)) do mc_sample
-logpdf(q_stop, mc_sample)
end
return entropy_stl - entropy(q) + entropy(q_stop)
end
9 changes: 1 addition & 8 deletions src/objectives/elbo/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,6 @@ function Base.show(io::IO, obj::RepGradELBO)
return print(io, ")")
end

function estimate_entropy_maybe_stl(
entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop
)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
return estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

function estimate_energy_with_samples(prob, samples)
return mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end
Expand All @@ -98,7 +91,7 @@ function reparam_with_entropy(
rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator
)
samples = rand(rng, q, n_samples)
entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop)
entropy = estimate_entropy(ent_est, samples, q, q_stop)
return samples, entropy
end

Expand Down
6 changes: 3 additions & 3 deletions src/optimization/clip_scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ Optimisers.@def struct ClipScale <: AbstractOperator
epsilon = 1e-5
end

function apply(::ClipScale, family::Type, params, restructure)
function apply(::ClipScale, family::Type, state, params, restructure)
return error("`ClipScale` is not defined for the variational family of type $(family).")
end

function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure)
function apply(op::ClipScale, ::Type{<:MvLocationScale}, state, params, restructure)
q = restructure(params)
ϵ = convert(eltype(params), op.epsilon)

Expand All @@ -26,7 +26,7 @@ function apply(op::ClipScale, ::Type{<:MvLocationScale}, params, restructure)
return params
end

function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, params, restructure)
function apply(op::ClipScale, ::Type{<:MvLocationScaleLowRank}, state, params, restructure)
q = restructure(params)
ϵ = convert(eltype(params), op.epsilon)

Expand Down
61 changes: 61 additions & 0 deletions src/optimization/proximal_location_scale_entropy.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

"""
    ProximalLocationScaleEntropy()

Proximal operator for the entropy of a location-scale distribution, which is defined as
```math
\\mathrm{prox}(\\lambda) = \\argmin_{\\lambda^{\\prime}} - \\mathbb{H}(q_{\\lambda^{\\prime}}) + \\frac{1}{2 \\gamma_t} {\\left\\lVert \\lambda - \\lambda^{\\prime} \\right\\rVert}^2 ,
```
where \$\\gamma_t\$ is the stepsize of the optimizer used with the proximal operator.
This assumes the variational family is `<:MvLocationScale` and the optimizer is one of the following:
- `DoG`
- `DoWG`
- `Descent`

For ELBO maximization, since this proximal operator handles the entropy, the gradient estimator for the ELBO must ignore the entropy term.
That is, the `entropy` keyword argument of `RepGradELBO` must be one of the following:
- `ClosedFormEntropyZeroGradient`
- `StickingTheLandingEntropyZeroGradient`
"""
struct ProximalLocationScaleEntropy <: AbstractOperator end

# Fallback for unsupported variational families: fail loudly rather than
# silently skipping the proximal step.
function apply(::ProximalLocationScaleEntropy, family, state, params, restructure)
    msg = "`ProximalLocationScaleEntropy` only supports `<:MvLocationScale`."
    return error(msg)
end

# Fallback for optimization rules whose stepsize cannot be recovered from
# their optimizer state.
function stepsize_from_optimizer_state(rule::Optimisers.AbstractRule, state)
    msg = "`ProximalLocationScaleEntropy` does not support optimization rule $(typeof(rule))."
    return error(msg)
end

# `Descent` exposes its constant stepsize directly as the hyperparameter `eta`.
stepsize_from_optimizer_state(rule::Descent, ::Any) = rule.eta

# NOTE(review): entries 2 and 3 of the state tuple correspond to the `v` and
# `r` in Optimisers.jl's DoG/DoWG state — confirm against the Optimisers.jl
# source if the state layout changes.
function stepsize_from_optimizer_state(::DoG, state)
    v = state[2]
    r = state[3]
    return r / sqrt(v)
end

function stepsize_from_optimizer_state(::DoWG, state)
    v = state[2]
    r = state[3]
    return r * r / sqrt(v)
end

# Apply the entropy proximal step to a (flattened) `MvLocationScale`
# approximation and return the updated flat parameters.
function apply(
    ::ProximalLocationScaleEntropy,
    ::Type{<:MvLocationScale},
    leaf::Optimisers.Leaf{<:Union{<:DoG,<:DoWG,<:Descent},S},
    params,
    restructure,
) where {S}
    q = restructure(params)
    gamma = stepsize_from_optimizer_state(leaf.rule, leaf.state)

    # Update each diagonal entry σ of the scale matrix in place:
    # σ ← (σ + √(σ² + 4γ)) / 2, the positive root of σ′² − σσ′ − γ = 0.
    for idx in diagind(q.scale)
        sigma = q.scale[idx]
        q.scale[idx] = sigma + (sqrt(sigma^2 + 4 * gamma) - sigma) / 2
    end

    flat_params, _ = Optimisers.destructure(q)
    return flat_params
end
2 changes: 1 addition & 1 deletion src/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function optimize(

grad = DiffResults.gradient(grad_buf)
opt_st, params = Optimisers.update!(opt_st, params, grad)
params = apply(operator, typeof(q_init), params, restructure)
params = apply(operator, typeof(q_init), opt_st, params, restructure)
avg_st = apply(averager, avg_st, params)

if !isnothing(callback)
Expand Down
1 change: 1 addition & 0 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

AD_repgradelbo_distributionsad = if TEST_GROUP == "Enzyme"
Dict(
:Enzyme => AutoEnzyme(;
Expand Down
Loading
Loading