Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Release 0.5

## Interface Changes

An additional layer of indirection, `AbstractAlgorithms`, has been added.
Previously, all variational inference algorithms were assumed to run SGD in parameter space.
This design, however, is proving to be too rigid.
Instead, each algorithm is now assumed to implement three simple interfaces: `init`, `step`, and `output`.
Algorithms that run SGD in parameter space now need to implement the `AbstractVariationalObjective` interface of `ParamSpaceSGD <: AbstractAlgorithms`, which is a general implementation of the new interface.
Therefore, the old behavior of `AdvancedVI` is fully inherited by `ParamSpaceSGD`.

## Internal Changes

The state of the objectives now use a concrete type.
Related to this, the objective `state` argument in `estimate_gradient!` has been moved to the front to avoid type ambiguities.
4 changes: 2 additions & 2 deletions src/algorithms/paramspacesgd/abstractobjective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ function estimate_objective end
export estimate_objective

"""
estimate_gradient!(rng, obj, adtype, out, params, restructure, obj_state)
estimate_gradient!(rng, obj, adtype, out, obj_state, params, restructure)

Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`

Expand All @@ -68,9 +68,9 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
- `obj::AbstractVariationalObjective`: Variational objective.
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates.
- `obj_state`: Previous state of the objective.
- `params`: Variational parameters to evaluate the gradient on.
- `restructure`: Function that reconstructs the variational approximation from `params`.
- `obj_state`: Previous state of the objective.

# Returns
- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates.
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms/paramspacesgd/paramspacesgd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ function step(
params, re = Optimisers.destructure(q)

grad_buf, obj_st, info = estimate_gradient!(
rng, objective, adtype, grad_buf, params, re, obj_st, objargs...
rng, objective, adtype, grad_buf, obj_st, params, re, objargs...
)

grad = DiffResults.gradient(grad_buf)
Expand Down
13 changes: 9 additions & 4 deletions src/algorithms/paramspacesgd/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ struct RepGradELBO{EntropyEst<:AbstractEntropyEstimator} <: AbstractVariationalO
n_samples::Int
end

"""
    RepGradELBOState

Concrete state of the `RepGradELBO` objective.

# Fields
- `problem`: The (AD-wrapped) target problem the objective operates on.
- `obj_ad_prep`: Prepared gradient computation object produced by
  `AdvancedVI._prepare_gradient` — TODO confirm against `init`.
"""
struct RepGradELBOState{P,Prep}
    problem::P
    obj_ad_prep::Prep
end

function init(
rng::Random.AbstractRNG,
obj::RepGradELBO,
Expand Down Expand Up @@ -56,7 +61,7 @@ function init(
obj_ad_prep = AdvancedVI._prepare_gradient(
estimate_repgradelbo_ad_forward, adtype, params, aux
)
return (obj_ad_prep=obj_ad_prep, problem=ad_prob)
return RepGradELBOState(ad_prob, obj_ad_prep)
end

function RepGradELBO(n_samples::Int; entropy::AbstractEntropyEstimator=ClosedFormEntropy())
Expand Down Expand Up @@ -143,9 +148,9 @@ function estimate_gradient!(
obj::RepGradELBO,
adtype::ADTypes.AbstractADType,
out::DiffResults.MutableDiffResult,
state::RepGradELBOState,
params,
restructure,
state,
args...,
)
(; obj_ad_prep, problem) = state
Expand All @@ -162,6 +167,6 @@ function estimate_gradient!(
estimate_repgradelbo_ad_forward, out, obj_ad_prep, adtype, params, aux
)
nelbo = DiffResults.value(out)
stat = (elbo=(-nelbo),)
return out, state, stat
info = (elbo=(-nelbo),)
return out, state, info
end
9 changes: 7 additions & 2 deletions src/algorithms/paramspacesgd/scoregradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ struct ScoreGradELBO <: AbstractVariationalObjective
n_samples::Int
end

"""
    ScoreGradELBOState

Concrete state of the `ScoreGradELBO` objective.

# Fields
- `problem`: The target problem the objective operates on.
- `obj_ad_prep`: Prepared gradient computation object produced by
  `AdvancedVI._prepare_gradient` — TODO confirm against `init`.
"""
struct ScoreGradELBOState{P,Prep}
    problem::P
    obj_ad_prep::Prep
end

function init(
rng::Random.AbstractRNG,
obj::ScoreGradELBO,
Expand All @@ -31,7 +36,7 @@ function init(
obj_ad_prep = AdvancedVI._prepare_gradient(
estimate_scoregradelbo_ad_forward, adtype, params, aux
)
return (obj_ad_prep=obj_ad_prep, problem=prob)
return ScoreGradELBOState(prob, obj_ad_prep)
end

function Base.show(io::IO, obj::ScoreGradELBO)
Expand Down Expand Up @@ -83,9 +88,9 @@ function AdvancedVI.estimate_gradient!(
obj::ScoreGradELBO,
adtype::ADTypes.AbstractADType,
out::DiffResults.MutableDiffResult,
state::ScoreGradELBOState,
params,
restructure,
state,
)
q = restructure(params)
(; obj_ad_prep, problem) = state
Expand Down
Loading