195 changes: 195 additions & 0 deletions test/algorithms/klminrepgraddescent.jl
@@ -0,0 +1,195 @@

@testset "KLMinRepGradDescent" begin
begin
modelstats = normal_meanfield(Random.default_rng(), Float64)
(; model, n_dims, μ_true, L_true) = modelstats

q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

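# Smoke test: a single optimization step should run without error for
# different numbers of Monte Carlo samples.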
@testset "basic n_samples=$(n_samples)" for n_samples in [1, 10]
alg = KLMinRepGradDescent(AD; n_samples, operator=ClipScale())
T = 1
optimize(alg, T, model, q0; show_progress=PROGRESS)
end

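# The NamedTuple returned by the callback should be merged into the
# per-iteration info entries.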
@testset "callback" begin
alg = KLMinRepGradDescent(AD; operator=ClipScale())
T = 10
callback(; iteration, kwargs...) = (iteration_check=iteration,)
_, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS)
@test [i.iteration_check for i in info] == 1:T
end

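# For this synthetic model the objective vanishes at the true target, so a
# large-sample estimate at q_true should be close to zero.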
@testset "estimate_objective" begin
alg = KLMinRepGradDescent(AD; operator=ClipScale())
q_true = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))

obj_est = estimate_objective(alg, q_true, model)
@test isfinite(obj_est)

obj_est = estimate_objective(alg, q_true, model; n_samples=1)
@test isfinite(obj_est)

obj_est = estimate_objective(alg, q_true, model; n_samples=3)
@test isfinite(obj_est)

obj_est = estimate_objective(alg, q_true, model; n_samples=10^5)
@test obj_est ≈ 0 atol=1e-2
end

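# Two runs from identically seeded RNGs must produce bitwise-identical
# variational parameters.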
@testset "determinism" begin
alg = KLMinRepGradDescent(AD; operator=ClipScale())

seed = 0x38bef07cf9cc549d
rng = StableRNG(seed)
T = 10

q_out, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS)
μ = q_out.location
L = q_out.scale

rng_repl = StableRNG(seed)
q_out, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS)
μ_repl = q_out.location
L_repl = q_out.scale
@test μ == μ_repl
@test L == L_repl
end

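# Location-scale families need a projection such as ClipScale to keep the
# scale valid; KLMinRepGradDescent should warn when IdentityOperator is used.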
@testset "warn MvLocationScale with IdentityOperator" begin
@test_warn "IdentityOperator" begin
alg′ = KLMinRepGradDescent(AD; operator=IdentityOperator())
optimize(alg′, 1, model, q0; show_progress=false)
end
end

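# Sticking-the-landing (STL) property: when q matches the target, the STL
# gradient estimator has zero variance, so even a single-sample gradient
# evaluated at q_true should vanish up to numerical tolerance.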
@testset "STL variance reduction" begin
@testset for n_montecarlo in [1, 10]
q_true = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))
params, re = Optimisers.destructure(q_true)
obj = RepGradELBO(n_montecarlo; entropy=StickingTheLandingEntropy())
out = DiffResults.DiffResult(zero(eltype(params)), similar(params))

aux = (
rng=Random.default_rng(),
obj=obj,
problem=model,
restructure=re,
q_stop=q_true,
adtype=AD,
)
AdvancedVI._value_and_gradient!(
AdvancedVI.estimate_repgradelbo_ad_forward, out, AD, params, aux
)
grad = DiffResults.gradient(out)
@test norm(grad) ≈ 0 atol = 1e-5
end
end
end

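# The variational parameters and the reported ELBO should preserve the
# floating-point type of the problem.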
@testset "type stability realtype=$(realtype)" for realtype in [Float32, Float64]
modelstats = normal_meanfield(Random.default_rng(), realtype)
(; model, n_dims, μ_true, L_true) = modelstats

T = 1
alg = KLMinRepGradDescent(AD; n_samples=10, operator=ClipScale())
q0 = MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))

q_out, info, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)

@test eltype(q_out.location) == realtype
@test eltype(q_out.scale) == realtype
@test typeof(first(info).elbo) == realtype
end

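# With a small constant stepsize, T iterations should at least halve the
# squared parameter distance to the optimum.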
@testset "convergence $(entropy)" for entropy in
[ClosedFormEntropy(), StickingTheLandingEntropy()]
modelstats = normal_meanfield(Random.default_rng(), Float64)
(; model, n_dims, μ_true, L_true, is_meanfield) = modelstats

T = 1000
optimizer = Descent(1e-3)
alg = KLMinRepGradDescent(AD; entropy, optimizer, operator=ClipScale())
q0 = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))

q_out, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)

Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
Δλ = sum(abs2, q_out.location - μ_true) + sum(abs2, q_out.scale - L_true)

@test Δλ ≤ Δλ0/2
end

@testset "subsampling" begin
n_data = 8

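# The subsampled objective should agree with the full-batch objective in
# expectation, so large-sample estimates should match within ~10%.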
@testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4]
modelstats = subsamplednormal(Random.default_rng(), n_data)
(; model, n_dims, μ_true, L_true) = modelstats

L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)
operator = ClipScale()

subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
alg = KLMinRepGradDescent(AD; n_samples=10, operator)
alg_sub = KLMinRepGradDescent(AD; n_samples=10, subsampling, operator)

obj_full = estimate_objective(alg, q0, model; n_samples=10^5)
obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5)
@test obj_full ≈ obj_sub rtol=0.1
end

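# Reshuffling subsampling draws its batches from the supplied RNG, so
# seeded runs must remain reproducible.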
@testset "determinism" begin
seed = 0x38bef07cf9cc549d
rng = StableRNG(seed)

modelstats = subsamplednormal(Random.default_rng(), n_data)
(; model, n_dims, μ_true, L_true) = modelstats

L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)

T = 10
batchsize = 3
subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
alg_sub = KLMinRepGradDescent(
AD; n_samples=10, subsampling, operator=ClipScale()
)

q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS)
μ = q.location
L = q.scale

rng_repl = StableRNG(seed)
q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS)
μ_repl = q.location
L_repl = q.scale
@test μ == μ_repl
@test L == L_repl
end

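# Doubly-stochastic setting (Monte Carlo gradients plus minibatching):
# optimization should still at least halve the initial parameter error.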
@testset "convergence" begin
modelstats = subsamplednormal(Random.default_rng(), n_data)
(; model, n_dims, μ_true, L_true) = modelstats

L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims))
q0 = FullRankGaussian(zeros(Float64, n_dims), L0)

T = 1000
batchsize = 1
optimizer = Descent(1e-3)
subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize)
alg_sub = KLMinRepGradDescent(
AD; n_samples=10, subsampling, optimizer, operator=ClipScale()
)

q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS)

Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true)

@test Δλ ≤ Δλ0/2
end
end
end
94 changes: 94 additions & 0 deletions test/algorithms/klminrepgraddescent_bijectors.jl
@@ -0,0 +1,94 @@

@testset "KLMinRepGradDescent with Bijectors" begin
begin
modelstats = normallognormal_meanfield(Random.default_rng(), Float64)
(; model, n_dims, μ_true, L_true) = modelstats

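# Transform the unconstrained mean-field Gaussian onto the model's support
# (the lognormal component is constrained to be positive).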
b = Bijectors.bijector(model)
binv = inverse(b)

q0_unconstr = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
q0 = Bijectors.transformed(q0_unconstr, binv)

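# Same objective checks as the unconstrained case, but with q pushed
# through the inverse bijector.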
@testset "estimate_objective" begin
alg = KLMinRepGradDescent(AD; operator=ClipScale())
q_true_unconstr = MeanFieldGaussian(Vector(μ_true), Diagonal(L_true))
q_true = Bijectors.transformed(q_true_unconstr, binv)

obj_est = estimate_objective(alg, q_true, model)
@test isfinite(obj_est)

obj_est = estimate_objective(alg, q_true, model; n_samples=1)
@test isfinite(obj_est)

obj_est = estimate_objective(alg, q_true, model; n_samples=3)
@test isfinite(obj_est)

obj_est = estimate_objective(alg, q_true, model; n_samples=10^5)
@test obj_est ≈ 0 atol=1e-2
end

@testset "determinism" begin
alg = KLMinRepGradDescent(AD; operator=ClipScale())

seed = 0x38bef07cf9cc549d
rng = StableRNG(seed)
T = 10

q_out, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS)
μ = q_out.dist.location
L = q_out.dist.scale

rng_repl = StableRNG(seed)
q_out, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS)
μ_repl = q_out.dist.location
L_repl = q_out.dist.scale
@test μ == μ_repl
@test L == L_repl
end

@testset "warn MvLocationScale with IdentityOperator" begin
@test_warn "IdentityOperator" begin
alg′ = KLMinRepGradDescent(AD; operator=IdentityOperator())
optimize(alg′, 1, model, q0; show_progress=false)
end
end
end

@testset "type stability realtype=$(realtype)" for realtype in [Float32, Float64]
modelstats = normallognormal_meanfield(Random.default_rng(), realtype)
(; model, n_dims, μ_true, L_true) = modelstats

b = Bijectors.bijector(model)
binv = inverse(b)

T = 1
alg = KLMinRepGradDescent(AD; n_samples=10, operator=ClipScale())
q0_unconstr = MeanFieldGaussian(
zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))
)
q0 = Bijectors.transformed(q0_unconstr, binv)

q_out, info, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)

@test eltype(q_out.dist.location) == realtype
@test eltype(q_out.dist.scale) == realtype
@test typeof(first(info).elbo) == realtype
end

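# Convergence is measured on the underlying unconstrained parameters of
# q.dist, mirroring the unconstrained convergence test.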
@testset "convergence $(entropy)" for entropy in
[ClosedFormEntropy(), StickingTheLandingEntropy()]
modelstats = normallognormal_meanfield(Random.default_rng(), Float64)
(; model, n_dims, μ_true, L_true, is_meanfield) = modelstats

b = Bijectors.bijector(model)
binv = inverse(b)

T = 1000
optimizer = Descent(1e-3)
alg = KLMinRepGradDescent(AD; entropy, optimizer, operator=ClipScale())
q0_unconstr = MeanFieldGaussian(zeros(n_dims), Diagonal(ones(n_dims)))
q0 = Bijectors.transformed(q0_unconstr, binv)

q_out, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS)

Δλ0 = sum(abs2, q0.dist.location - μ_true) + sum(abs2, q0.dist.scale - L_true)
Δλ = sum(abs2, q_out.dist.location - μ_true) + sum(abs2, q_out.dist.scale - L_true)

@test Δλ ≤ Δλ0/2
end
end