2 changes: 1 addition & 1 deletion docs/src/general.md
@@ -38,7 +38,7 @@ The operation of `optimize` can be simplified as follows:
```julia
function optimize([rng,] algorithm, max_iter, q_init, objargs; kwargs...)
info_total = NamedTuple[]
state = init(rng, algorithm, q_init)
state = init(rng, algorithm, q_init, prob)
for t in 1:max_iter
info = (iteration=t,)
state, terminate, info′ = step(
1 change: 1 addition & 0 deletions docs/src/paramspacesgd/general.md
@@ -77,6 +77,7 @@ AdvancedVI.init(
::Any,
::Any,
::Any,
::Any,
)
```
If this method is not implemented, the state will automatically be `nothing`.
4 changes: 2 additions & 2 deletions src/AdvancedVI.jl
@@ -199,15 +199,15 @@ Abstract type for a variational inference algorithm.
abstract type AbstractAlgorithm end

"""
init(rng, alg, prob, q_init)
init(rng, alg, q_init, prob)

Initialize `alg` given the initial variational approximation `q_init` and the target `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::AbstractAlgorithm`: Variational inference algorithm.
- `prob`: Target problem.
- `q_init`: Initial variational approximation.
- `prob`: Target problem.

# Returns
- `state`: Initial state of the algorithm.
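For illustration, a minimal sketch of how a downstream algorithm might implement the reordered `init` documented above. `MyAlgorithm` and `MyAlgorithmState` are hypothetical names; only the `(rng, alg, q_init, prob)` argument order is taken from the docstring.

```julia
using Random
using AdvancedVI

# Hypothetical algorithm type; only the argument order follows the docstring above.
struct MyAlgorithm <: AdvancedVI.AbstractAlgorithm end

struct MyAlgorithmState{Q,P}
    q::Q        # current variational approximation
    prob::P     # target problem
    iteration::Int
end

function AdvancedVI.init(rng::Random.AbstractRNG, alg::MyAlgorithm, q_init, prob)
    # The variational approximation now comes before the target problem.
    return MyAlgorithmState(q_init, prob, 0)
end
```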
11 changes: 7 additions & 4 deletions src/algorithms/paramspacesgd/abstractobjective.jl
@@ -13,17 +13,19 @@ If the estimator is stateful, it can implement `init` to initialize the state.
abstract type AbstractVariationalObjective end

"""
init(rng, obj, adtype, prob, params, restructure)
init(rng, obj, adtype, q_init, prob, params, restructure)

Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
Initialize a state of the variational objective `obj` given the initial variational approximation `q_init` and its parameters `params`.
This function needs to be implemented only if `obj` is stateful.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `obj::AbstractVariationalObjective`: Variational objective.
- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
- `q_init`: Initial variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
- `params`: Initial variational parameters.
- `restructure`: Function that reconstructs the variational approximation from `λ`.
- `restructure`: Function that reconstructs the variational approximation from `params`.
"""
function init(
::Random.AbstractRNG,
@@ -32,6 +34,7 @@ function init(
::Any,
::Any,
::Any,
::Any,
)
nothing
end
@@ -69,7 +72,7 @@ function set_objective_state_problem end
"""
estimate_gradient!(rng, obj, adtype, out, obj_state, params, restructure)

Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ`
Estimate (possibly stochastic) gradients of the variational objective `obj` with respect to the variational parameters `params`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
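To make the seven-argument form above concrete, here is a hedged sketch of a stateful custom objective. `MyObjective` and the contents of its state are illustrative assumptions; a stateless objective can omit this method and rely on the default that returns `nothing`.

```julia
using Random
using ADTypes
using AdvancedVI

# Hypothetical stateful objective; the argument order matches the docstring above.
struct MyObjective <: AdvancedVI.AbstractVariationalObjective
    n_samples::Int
end

function AdvancedVI.init(
    rng::Random.AbstractRNG,
    obj::MyObjective,
    adtype::ADTypes.AbstractADType,
    q_init,
    prob,
    params,
    restructure,
)
    # Draw Monte Carlo samples from the initial approximation and carry them,
    # together with the target problem, in the objective state.
    samples = rand(rng, q_init, obj.n_samples)
    return (samples=samples, prob=prob)
end
```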
4 changes: 2 additions & 2 deletions src/algorithms/paramspacesgd/paramspacesgd.jl
@@ -64,11 +64,11 @@ struct ParamSpaceSGDState{P,Q,GradBuf,OptSt,ObjSt,AvgSt}
avg_st::AvgSt
end

function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, prob, q_init)
function init(rng::Random.AbstractRNG, alg::ParamSpaceSGD, q_init, prob)
(; adtype, optimizer, averager, objective) = alg
params, re = Optimisers.destructure(q_init)
opt_st = Optimisers.setup(optimizer, params)
obj_st = init(rng, objective, adtype, prob, params, re)
obj_st = init(rng, objective, adtype, q_init, prob, params, re)
avg_st = init(averager, params)
grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
return ParamSpaceSGDState(prob, q_init, 0, grad_buf, opt_st, obj_st, avg_st)
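The only behavioral change in `ParamSpaceSGD`'s `init` is that `q_init` is now threaded through to the objective's `init`; the surrounding Optimisers.jl pattern is unchanged. A standalone sketch of that pattern follows, with a toy `q_init` that stands in for an actual variational family.

```julia
using Optimisers

# Toy "variational approximation": any Functors-compatible struct of arrays works.
q_init = (location=zeros(2), scale=ones(2))

# Flatten into a parameter vector plus a closure that rebuilds the original object.
params, restructure = Optimisers.destructure(q_init)
opt_st = Optimisers.setup(Optimisers.Adam(1e-2), params)

# One gradient step on the flat parameters, then rebuild the approximation.
grad = ones(length(params))
opt_st, params = Optimisers.update(opt_st, params, grad)
q_new = restructure(params)
```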
9 changes: 5 additions & 4 deletions src/algorithms/paramspacesgd/repgradelbo.jl
@@ -42,12 +42,13 @@ function init(
rng::Random.AbstractRNG,
obj::RepGradELBO,
adtype::ADTypes.AbstractADType,
prob::Prob,
q,
prob,
params,
restructure,
) where {Prob}
q_stop = restructure(params)
capability = LogDensityProblems.capabilities(Prob)
)
q_stop = q
capability = LogDensityProblems.capabilities(typeof(prob))
ad_prob = if capability < LogDensityProblems.LogDensityOrder{1}()
@info "The capability of the supplied `LogDensityProblem` $(capability) is less than $(LogDensityProblems.LogDensityOrder{1}()). `AdvancedVI` will attempt to directly differentiate through `LogDensityProblems.logdensity`. If this is not intended, please supply a log-density problem with capability at least $(LogDensityProblems.LogDensityOrder{1}())"
prob
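The capability check above decides whether the target can be handed to the AD backend as-is or must be differentiated through `LogDensityProblems.logdensity` directly. A zero-order target like the hypothetical one below would take the `@info` fallback branch.

```julia
using LogDensityProblems

# Hypothetical target that only declares zero-order capability (log density, no gradient).
struct MyGaussianTarget end

LogDensityProblems.logdensity(::MyGaussianTarget, x) = -sum(abs2, x) / 2
LogDensityProblems.dimension(::MyGaussianTarget) = 2
function LogDensityProblems.capabilities(::Type{MyGaussianTarget})
    return LogDensityProblems.LogDensityOrder{0}()
end

capability = LogDensityProblems.capabilities(typeof(MyGaussianTarget()))
capability < LogDensityProblems.LogDensityOrder{1}()  # true, so the fallback is used
```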
4 changes: 2 additions & 2 deletions src/algorithms/paramspacesgd/scoregradelbo.jl
@@ -35,12 +35,12 @@ function init(
rng::Random.AbstractRNG,
obj::ScoreGradELBO,
adtype::ADTypes.AbstractADType,
q_init,
prob,
params,
restructure,
)
q = restructure(params)
samples = rand(rng, q, obj.n_samples)
samples = rand(rng, q_init, obj.n_samples)
ℓπ = map(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
aux = (adtype=adtype, logprob_stop=ℓπ, samples_stop=samples, restructure=restructure)
obj_ad_prep = AdvancedVI._prepare_gradient(
3 changes: 2 additions & 1 deletion src/algorithms/paramspacesgd/subsampledobjective.jl
@@ -76,13 +76,14 @@ function init(
rng::Random.AbstractRNG,
subobj::SubsampledObjective,
adtype::ADTypes.AbstractADType,
q_init,
prob,
params,
restructure,
)
(; objective, subsampling) = subobj
sub_st = init(rng, subsampling)
obj_st = AdvancedVI.init(rng, objective, adtype, prob, params, restructure)
obj_st = AdvancedVI.init(rng, objective, adtype, q_init, prob, params, restructure)
return SubsampledObjectiveState(prob, sub_st, obj_st)
end

2 changes: 1 addition & 1 deletion src/optimize.jl
@@ -56,7 +56,7 @@ function optimize(
)
info_total = NamedTuple[]
state = if isnothing(state)
init(rng, algorithm, prob, q_init)
init(rng, algorithm, q_init, prob)
else
state
end
4 changes: 2 additions & 2 deletions test/algorithms/paramspacesgd/subsampledobj.jl
@@ -98,15 +98,15 @@ end

# Estimate using full batch
rng = StableRNG(seed)
full_state = AdvancedVI.init(rng, full_obj, AD, prob, params, restructure)
full_state = AdvancedVI.init(rng, full_obj, AD, q0, prob, params, restructure)
AdvancedVI.estimate_gradient!(
rng, full_obj, AD, out, full_state, params, restructure
)
grad_ref = DiffResults.gradient(out)

# Estimate the full batch gradient by averaging the minibatch gradients
rng = StableRNG(seed)
sub_state = AdvancedVI.init(rng, sub_obj, AD, prob, params, restructure)
sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, prob, params, restructure)
grad = mean(1:length(sub_obj.subsampling)) do _
# Fixing the RNG so that the same Monte Carlo samples are used across the batches
rng = StableRNG(seed)