diff --git a/docs/src/general.md b/docs/src/general.md
index c10c5b746..417d6b225 100644
--- a/docs/src/general.md
+++ b/docs/src/general.md
@@ -38,7 +38,7 @@ The operation of `optimize` can be simplified as follows:
 ```julia
 function optimize([rng,] algorithm, max_iter, q_init, objargs; kwargs...)
     info_total = NamedTuple[]
-    state = init(rng, algorithm, q_init)
+    state = init(rng, algorithm, q_init, prob)
     for t in 1:max_iter
         info = (iteration=t,)
         state, terminate, info′ = step(
diff --git a/docs/src/paramspacesgd/general.md b/docs/src/paramspacesgd/general.md
index 614a73822..347a9dc82 100644
--- a/docs/src/paramspacesgd/general.md
+++ b/docs/src/paramspacesgd/general.md
@@ -77,6 +77,7 @@ AdvancedVI.init(
     ::Any,
     ::Any,
     ::Any,
+    ::Any,
 )
 ```
 If this method is not implemented, the state will be automatically be `nothing`.
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index b9b3e979b..58826440e 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -199,15 +199,15 @@ Abstract type for a variational inference algorithm.
 abstract type AbstractAlgorithm end
 
 """
-    init(rng, alg, prob, q_init)
+    init(rng, alg, q_init, prob)
 
 Initialize `alg` given the initial variational approximation `q_init` and the target `prob`.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `alg::AbstractAlgorithm`: Variational inference algorithm.
-- `prob`: Target problem.
 - `q_init`: Initial variational approximation.
+- `prob`: Target problem.
 
 # Returns
 - `state`: Initial state of the algorithm.
diff --git a/src/algorithms/paramspacesgd/abstractobjective.jl b/src/algorithms/paramspacesgd/abstractobjective.jl
index 1680ecb1d..d125c3be8 100644
--- a/src/algorithms/paramspacesgd/abstractobjective.jl
+++ b/src/algorithms/paramspacesgd/abstractobjective.jl
@@ -13,17 +13,19 @@ If the estimator is stateful, it can implement `init` to initialize the state.
 abstract type AbstractVariationalObjective end
 
 """
-    init(rng, obj, adtype, prob, params, restructure)
+    init(rng, obj, adtype, q_init, prob, params, restructure)
 
-Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
+Initialize a state of the variational objective `obj` given the initial variational approximation `q_init` and its parameters `params`.
 This function needs to be implemented only if `obj` is stateful.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `obj::AbstractVariationalObjective`: Variational objective.
 ` `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `q_init`: Initial variational approximation.
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
 - `params`: Initial variational parameters.
-- `restructure`: Function that reconstructs the variational approximation from `λ`.
+- `restructure`: Function that reconstructs the variational approximation from `params`.
 """
 function init(
     ::Random.AbstractRNG,
@@ -32,6 +34,7 @@ function init(
     ::Any,
     ::Any,
     ::Any,
+    ::Any,
 )
     nothing
 end
@@ -69,7 +72,7 @@ function set_objective_state_problem end
 """
     estimate_gradient!(rng, obj, adtype, out, obj_state, params, restructure)
 
-Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
+Estimate (possibly stochastic) gradients of the variational objective `obj` with respect to the variational parameters `params`
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
diff --git a/src/algorithms/paramspacesgd/paramspacesgd.jl b/src/algorithms/paramspacesgd/paramspacesgd.jl
index 5d01c10ca..8e2144b21 100644
--- a/src/algorithms/paramspacesgd/paramspacesgd.jl
+++ b/src/algorithms/paramspacesgd/paramspacesgd.jl
@@ -64,11 +64,11 @@ struct ParamSpaceSGDState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
     avg_st::AvgSt
 end
 
-function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, prob, q_init)
+function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
     (; adtype, optimizer, averager, objective) = alg
     params, re = Optimisers.destructure(q_init)
     opt_st = Optimisers.setup(optimizer, params)
-    obj_st = init(rng, objective, adtype, prob, params, re)
+    obj_st = init(rng, objective, adtype, q_init, prob, params, re)
     avg_st = init(averager, params)
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
     return ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
diff --git a/src/algorithms/paramspacesgd/repgradelbo.jl b/src/algorithms/paramspacesgd/repgradelbo.jl
index 2d0c2c7b2..868dbc594 100644
--- a/src/algorithms/paramspacesgd/repgradelbo.jl
+++ b/src/algorithms/paramspacesgd/repgradelbo.jl
@@ -42,12 +42,13 @@ function init(
     rng::Random.AbstractRNG,
     obj::RepGradELBO,
     adtype::ADTypes.AbstractADType,
-    prob::Prob,
+    q,
+    prob,
     params,
     restructure,
-) where {Prob}
-    q_stop = restructure(params)
-    capability = LogDensityProblems.capabilities(Prob)
+)
+    q_stop = q
+    capability = LogDensityProblems.capabilities(typeof(prob))
     ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}()
         @info "The capability of the supplied `LogDensityProblem` $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}()). `AdvancedVI` will attempt to directly differentiate through `LogDensityProblems.logdensity`. If this is not intended, please supply a log-density problem with capability at least $(LogDensityProblems.LogDensityOrder{1}())"
         prob
diff --git a/src/algorithms/paramspacesgd/scoregradelbo.jl b/src/algorithms/paramspacesgd/scoregradelbo.jl
index 3213693d5..a00d532c9 100644
--- a/src/algorithms/paramspacesgd/scoregradelbo.jl
+++ b/src/algorithms/paramspacesgd/scoregradelbo.jl
@@ -35,12 +35,12 @@ function init(
     rng::Random.AbstractRNG,
     obj::ScoreGradELBO,
     adtype::ADTypes.AbstractADType,
+    q_init,
     prob,
     params,
     restructure,
 )
-    q = restructure(params)
-    samples = rand(rng, q, obj.n_samples)
+    samples = rand(rng, q_init, obj.n_samples)
     ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
     aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
     obj_ad_prep = AdvancedVI._prepare_gradient(
diff --git a/src/algorithms/paramspacesgd/subsampledobjective.jl b/src/algorithms/paramspacesgd/subsampledobjective.jl
index b3905dd3e..a904eb245 100644
--- a/src/algorithms/paramspacesgd/subsampledobjective.jl
+++ b/src/algorithms/paramspacesgd/subsampledobjective.jl
@@ -76,13 +76,14 @@ function init(
     rng::Random.AbstractRNG,
     subobj::SubsampledObjective,
     adtype::ADTypes.AbstractADType,
+    q_init,
     prob,
     params,
     restructure,
 )
     (; objective, subsampling) = subobj
     sub_st = init(rng, subsampling)
-    obj_st = AdvancedVI.init(rng, objective, adtype, prob, params, restructure)
+    obj_st = AdvancedVI.init(rng, objective, adtype, q_init, prob, params, restructure)
     return SubsampledObjectiveState(prob, sub_st, obj_st)
 end
 
diff --git a/src/optimize.jl b/src/optimize.jl
index e538a0838..19bdc5257 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -56,7 +56,7 @@ function optimize(
 )
     info_total = NamedTuple[]
     state = if isnothing(state)
-        init(rng, algorithm, prob, q_init)
+        init(rng, algorithm, q_init, prob)
     else
         state
     end
diff --git a/test/algorithms/paramspacesgd/subsampledobj.jl b/test/algorithms/paramspacesgd/subsampledobj.jl
index dec42148b..f7e81d55f 100644
--- a/test/algorithms/paramspacesgd/subsampledobj.jl
+++ b/test/algorithms/paramspacesgd/subsampledobj.jl
@@ -98,7 +98,7 @@ end
 
     # Estimate using full batch
     rng = StableRNG(seed)
-    full_state = AdvancedVI.init(rng, full_obj, AD, prob, params, restructure)
+    full_state = AdvancedVI.init(rng, full_obj, AD, q0, prob, params, restructure)
     AdvancedVI.estimate_gradient!(
         rng, full_obj, AD, out, full_state, params, restructure
     )
@@ -106,7 +106,7 @@
 
     # Estimate the full batch gradient by averaging the minibatch gradients
    rng = StableRNG(seed)
-    sub_state = AdvancedVI.init(rng, sub_obj, AD, prob, params, restructure)
+    sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, prob, params, restructure)
    grad = mean(1:length(sub_obj.subsampling)) do _
        # Fixing the RNG so that the same Monte Carlo samples are used across the batches
        rng = StableRNG(seed)
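
The hunks above all apply the same signature change: the initial variational approximation `q_init` now precedes the target `prob` in `init`, and objective-level `init` methods receive `q_init` directly instead of rebuilding it with `restructure(params)`. The following is a minimal sketch, not part of the patch, of how a hypothetical downstream objective (the names `MyObjective` and `n_samples` are illustrative assumptions) might implement the new seven-argument `init` after this change.

```julia
using Random, ADTypes, AdvancedVI

# Hypothetical stateful objective used only to illustrate the new argument order.
struct MyObjective <: AdvancedVI.AbstractVariationalObjective
    n_samples::Int
end

# New ordering: (rng, obj, adtype, q_init, prob, params, restructure).
function AdvancedVI.init(
    rng::Random.AbstractRNG,
    obj::MyObjective,
    adtype::ADTypes.AbstractADType,
    q_init,       # initial variational approximation, now passed in directly
    prob,         # target problem implementing the `LogDensityProblem` interface
    params,       # flattened variational parameters
    restructure,  # rebuilds a variational approximation from `params`
)
    # Previously the approximation had to be recovered via `restructure(params)`;
    # with `q_init` available, samples can be drawn from it directly.
    samples = rand(rng, q_init, obj.n_samples)
    return (samples=samples,)
end
```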