refactor optimize warm-starting interface, add objargs argument
Red-Portal committed Aug 24, 2023
1 parent f593a67 commit ff32ac6
Showing 3 changed files with 66 additions and 38 deletions.
64 changes: 33 additions & 31 deletions src/optimize.jl
@@ -8,7 +8,8 @@ end
objective ::AbstractVariationalObjective,
restructure,
λ₀ ::AbstractVector{<:Real},
n_max_iter ::Int;
n_max_iter ::Int,
objargs...;
kwargs...
)
@@ -17,7 +18,8 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
optimize(
objective ::AbstractVariationalObjective,
q,
n_max_iter::Int;
n_max_iter::Int,
objargs...;
kwargs...
)
@@ -29,83 +31,83 @@ Optimize the variational objective `objective` by estimating (stochastic) gradie
- `restructure`: Function that reconstructs the variational approximation from the flattened parameters.
- `q`: Initial variational approximation. The variational parameters must be extractable through `Optimisers.destructure`.
- `n_max_iter`: Maximum number of iterations.
- `objargs...`: Arguments to be passed to `objective`.
- `kwargs...`: Additional keyword arguments. (See below.)
# Keyword Arguments
- `adbackend`: Automatic differentiation backend. (Type: `<: ADtypes.AbstractADType`.)
- `optimizer`: Optimizer used for inference. (Type: `<: Optimisers.AbstractRule`; Default: `Adam`.)
- `rng`: Random number generator. (Type: `<: AbstractRNG`; Default: `Random.default_rng()`.)
- `show_progress`: Whether to show the progress bar. (Type: `<: Bool`; Default: `true`.)
- `callback!`: Callback function called after every iteration. The signature is `cb(; obj_state, stats, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`. If the estimator associated with `objective` is stateful, `obj_state` contains its state. (Default: `nothing`.) `g` is the stochastic gradient.
- `callback!`: Callback function called after every iteration. The signature is `cb(; stat, restructure, λ, g)`, which returns a dictionary-like object containing statistics to be displayed on the progress bar. The variational approximation can be reconstructed as `restructure(λ)`, and `g` is the stochastic estimate of the gradient. (Default: `nothing`.)
- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.)
When resuming from the state of a previous run, use the following keyword arguments:
- `opt_state`: Initial state of the optimizer.
- `obj_state`: Initial state of the objective.
- `state`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) (Type: `<: NamedTuple`.)
# Returns
- `λ`: Variational parameters optimizing the variational objective.
- `stats`: Statistics gathered during inference.
- `opt_state`: Final state of the optimiser.
- `obj_state`: Final state of the objective.
- `logstats`: Statistics and logs gathered during optimization.
- `state`: Collection of the final internal states of optimization. This can be used later to warm-start from the last iteration of the corresponding run.
"""
function optimize(
objective ::AbstractVariationalObjective,
restructure,
λ₀ ::AbstractVector{<:Real},
n_max_iter ::Int;
adbackend::AbstractADType,
n_max_iter ::Int,
objargs...;
adbackend ::AbstractADType,
optimizer ::Optimisers.AbstractRule = Optimisers.Adam(),
rng ::AbstractRNG = default_rng(),
show_progress::Bool = true,
opt_state = nothing,
obj_state = nothing,
state ::NamedTuple = NamedTuple(),
callback! = nothing,
prog = ProgressMeter.Progress(
n_max_iter;
desc = "Optimizing",
barlen = 31,
showspeed = true,
enabled = show_progress
)
)
)
λ = copy(λ₀)
opt_state = isnothing(opt_state) ? Optimisers.setup(optimizer, λ) : opt_state
obj_state = isnothing(obj_state) ? init(rng, objective, λ, restructure) : obj_state
grad_buf = DiffResults.GradientResult(λ)
stats = NamedTuple[]
λ = copy(λ₀)
opt_st = haskey(state, :opt) ? state.opt : Optimisers.setup(optimizer, λ)
obj_st = haskey(state, :obj) ? state.obj : init(rng, objective, λ, restructure)
grad_buf = DiffResults.GradientResult(λ)
logstats = NamedTuple[]

for t = 1:n_max_iter
stat = (iteration=t,)

grad_buf, obj_state, stat′ = estimate_gradient(
rng, adbackend, objective, obj_state, λ, restructure, grad_buf)
grad_buf, obj_st, stat′ = estimate_gradient(
rng, adbackend, objective, obj_st,
λ, restructure, grad_buf; objargs...
)
stat = merge(stat, stat′)

g = DiffResults.gradient(grad_buf)
opt_state, λ = Optimisers.update!(opt_state, λ, g)
stat′ = (iteration = t,)
stat = merge(stat, stat′)
g = DiffResults.gradient(grad_buf)
opt_st, λ = Optimisers.update!(opt_st, λ, g)

if !isnothing(callback!)
stat′ = callback!(; obj_state, stat, restructure, λ, g)
stat′ = callback!(; stat, restructure, λ, g)
stat = !isnothing(stat′) ? merge(stat′, stat) : stat
end

@debug "Iteration $t" stat...

pm_next!(prog, stat)
push!(stats, stat)
push!(logstats, stat)
end
λ, map(identity, stats), opt_state, obj_state
state = (opt=opt_st, obj=obj_st)
logstats = map(identity, logstats)
λ, logstats, state
end

function optimize(objective ::AbstractVariationalObjective,
q₀,
n_max_iter::Int;
kwargs...)
λ, restructure = Optimisers.destructure(q₀)
λ, stats, opt_state, obj_state = optimize(
λ, logstats, state = optimize(
objective, restructure, λ, n_max_iter; kwargs...
)
restructure(λ), stats, opt_state, obj_state
restructure(λ), logstats, state
end
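
For orientation, here is a minimal sketch of how the refactored interface reads after this commit, pieced together from the updated docstring above and the warm-start test added in test/optimize.jl below. The names `obj`, `q₀`, `adbackend`, and `T` are stand-ins for a configured objective, initial variational approximation, AD backend, and iteration budget; the optimizer settings are illustrative and not part of this commit.

    λ₀, restructure = Optimisers.destructure(q₀)

    # `optimize` now returns (λ, logstats, state) instead of
    # (λ, stats, opt_state, obj_state).
    λ, logstats, state = optimize(
        obj, restructure, λ₀, div(T, 2);
        adbackend,
        optimizer     = Optimisers.Adam(1e-2),
        show_progress = false
    )

    # Warm-start the remaining iterations by passing the returned `state` back in.
    λ, logstats, state = optimize(
        obj, restructure, λ, T - div(T, 2);
        adbackend,
        optimizer     = Optimisers.Adam(1e-2),
        show_progress = false,
        state
    )

    q = restructure(λ)

The same return-value change is what drives the call-site updates in the two test files below: the old four-element unpacking `q, stats, _, _ = optimize(...)` becomes `q, stats, _ = optimize(...)`.
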
6 changes: 3 additions & 3 deletions test/advi_locscale.jl
@@ -50,7 +50,7 @@ using ReTest

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
q, stats, _, _ = optimize(
q, stats, _ = optimize(
obj, q₀, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
@@ -69,7 +69,7 @@ using ReTest

@testset "determinism" begin
rng = Philox4x(UInt64, seed, 8)
q, stats, _, _ = optimize(
q, stats, _ = optimize(
obj, q₀, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
@@ -80,7 +80,7 @@ using ReTest
L = q.scale

rng_repl = Philox4x(UInt64, seed, 8)
q, stats, _, _ = optimize(
q, stats, _ = optimize(
obj, q₀, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
34 changes: 30 additions & 4 deletions test/optimize.jl
@@ -21,7 +21,7 @@ using ReTest
optimizer = Optimisers.Adam(1e-2)

rng = Philox4x(UInt64, seed, 8)
q_ref, stats_ref, _, _ = optimize(
q_ref, stats_ref, _ = optimize(
obj, q₀, T;
optimizer,
show_progress = false,
@@ -34,7 +34,7 @@ using ReTest
λ₀, re = Optimisers.destructure(q₀)

rng = Philox4x(UInt64, seed, 8)
λ, stats, _, _ = optimize(
λ, stats, _ = optimize(
obj, re, λ₀, T;
optimizer,
show_progress = false,
@@ -49,18 +49,44 @@ using ReTest
rng = Philox4x(UInt64, seed, 8)
test_values = rand(rng, T)

callback!(; stat, obj_state, restructure, λ, g) = begin
callback!(; stat, restructure, λ, g) = begin
(test_value = test_values[stat.iteration],)
end

rng = Philox4x(UInt64, seed, 8)
_, stats, _, _ = optimize(
_, stats, _ = optimize(
obj, q₀, T;
optimizer,
show_progress = false,
rng,
adbackend,
callback!
)
@test [stat.test_value for stat ∈ stats] == test_values
end

@testset "warm start" begin
rng = Philox4x(UInt64, seed, 8)

T_first = div(T,2)
T_last = T - T_first

q_first, _, state = optimize(
obj, q₀, T_first;
optimizer,
show_progress = false,
rng,
adbackend
)

q, stats, _ = optimize(
obj, q_first, T_last;
optimizer,
show_progress = false,
state,
rng,
adbackend
)
@test q == q_ref
end
end
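
A consequence of the new interface worth noting (an inference from the implementation above, not something this commit tests): `optimize` looks up the optimizer and objective states separately via `haskey(state, :opt)` and `haskey(state, :obj)`, so a partial warm start that reuses only the optimizer state should also be possible. A sketch under the same setup as the warm-start testset (`obj`, `q₀`, `optimizer`, `rng`, `adbackend`, `T`):

    T_first = div(T, 2)
    T_last  = T - T_first

    q_first, _, state_first = optimize(
        obj, q₀, T_first;
        optimizer,
        show_progress = false,
        rng,
        adbackend
    )

    # Pass back only the Optimisers state; the objective state is re-initialized
    # because the NamedTuple has no `:obj` field.
    q, stats, _ = optimize(
        obj, q_first, T_last;
        optimizer,
        show_progress = false,
        state = (opt = state_first.opt,),
        rng,
        adbackend
    )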
