diff --git a/test/algorithms/subsampledobj.jl b/test/algorithms/subsampledobj.jl
index c5b8720ec..c2907e62c 100644
--- a/test/algorithms/subsampledobj.jl
+++ b/test/algorithms/subsampledobj.jl
@@ -1,44 +1,12 @@
-struct SubsampledNormals{D<:Normal,F<:Real}
-    dists::Vector{D}
-    likeadj::F
-end
-
-function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int)
-    μs = randn(rng, n_normals)
-    σs = ones(n_normals)
-    dists = Normal.(μs, σs)
-    SubsampledNormals{eltype(dists),Float64}(dists, 1.0)
-end
-
-function LogDensityProblems.logdensity(m::SubsampledNormals, x)
-    (; likeadj, dists) = m
-    likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists)
-end
-
-function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x)
-    return (
-        LogDensityProblems.logdensity(m, x),
-        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x),
-    )
-end
-
-function LogDensityProblems.capabilities(::Type{<:SubsampledNormals})
-    return LogDensityProblems.LogDensityOrder{1}()
-end
-
-function AdvancedVI.subsample(m::SubsampledNormals, idx)
-    n_data = length(m.dists)
-    SubsampledNormals(m.dists[idx], n_data/length(idx))
-end
-
 @testset "SubsampledObjective" begin
     seed = (0x38bef07cf9cc549d)
     n_data = 8
 
-    prob = SubsampledNormals(Random.default_rng(), n_data)
-    μ0 = [mean([mean(dist) for dist in prob.dists])]
-    q0 = MeanFieldGaussian(μ0, Diagonal(ones(1)))
+    modelstats = subsamplednormal(Random.default_rng(), n_data)
+    (; model, n_dims, μ_true, L_true) = modelstats
+
+    q0 = MeanFieldGaussian(μ_true, Diagonal(diag(L_true)))
 
     full_obj = RepGradELBO(10)
 
     @testset "algorithm constructors" begin
@@ -47,17 +15,17 @@ end
         alg = KLMinRepGradDescent(
             AD; n_samples=10, subsampling=sub, operator=ClipScale()
         )
-        _, info, _ = optimize(alg, 10, prob, q0; show_progress=false)
+        _, info, _ = optimize(alg, 10, model, q0; show_progress=false)
         @test isfinite(last(info).elbo)
 
         alg = KLMinRepGradProxDescent(AD; n_samples=10, subsampling=sub)
-        _, info, _ = optimize(alg, 10, prob, q0; show_progress=false)
+        _, info, _ = optimize(alg, 10, model, q0; show_progress=false)
         @test isfinite(last(info).elbo)
 
         alg = KLMinScoreGradDescent(
             AD; n_samples=100, subsampling=sub, operator=ClipScale()
         )
-        _, info, _ = optimize(alg, 10, prob, q0; show_progress=false)
+        _, info, _ = optimize(alg, 10, model, q0; show_progress=false)
         @test isfinite(last(info).elbo)
     end
 end
@@ -69,25 +37,25 @@ end
         sub_obj = alg.objective
 
         rng = StableRNG(seed)
-        q_avg, _, _ = optimize(rng, alg, T, prob, q0; show_progress=false)
+        q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=false)
 
         rng = StableRNG(seed)
-        q_avg_ref, _, _ = optimize(rng, alg, T, prob, q0; show_progress=false)
+        q_avg_ref, _, _ = optimize(rng, alg, T, model, q0; show_progress=false)
         @test q_avg == q_avg_ref
 
         rng = StableRNG(seed)
-        sub_objval_ref = estimate_objective(rng, sub_obj, q0, prob)
+        sub_objval_ref = estimate_objective(rng, sub_obj, q0, model)
 
         rng = StableRNG(seed)
-        sub_objval = estimate_objective(rng, sub_obj, q0, prob)
+        sub_objval = estimate_objective(rng, sub_obj, q0, model)
         @test sub_objval == sub_objval_ref
     end
 
     @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4]
         sub = ReshufflingBatchSubsampling(1:n_data, batchsize)
         sub_obj′ = SubsampledObjective(full_obj, sub)
-        full_objval = estimate_objective(full_obj, q0, prob; n_samples=10^8)
-        sub_objval = estimate_objective(sub_obj′, q0, prob; n_samples=10^8)
+        full_objval = estimate_objective(full_obj, q0, model; n_samples=10^8)
+        sub_objval = estimate_objective(sub_obj′, q0, model; n_samples=10^8)
 
         @test full_objval ≈ sub_objval rtol=0.1
     end
@@ -100,7 +68,7 @@ end
 
         # Estimate using full batch
         rng = StableRNG(seed)
-        full_state = AdvancedVI.init(rng, full_obj, AD, q0, prob, params, restructure)
+        full_state = AdvancedVI.init(rng, full_obj, AD, q0, model, params, restructure)
         AdvancedVI.estimate_gradient!(
             rng, full_obj, AD, out, full_state, params, restructure
         )
@@ -108,7 +76,7 @@ end
 
         # Estimate the full batch gradient by averaging the minibatch gradients
         rng = StableRNG(seed)
-        sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, prob, params, restructure)
+        sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, model, params, restructure)
         grad = mean(1:length(sub_obj.subsampling)) do _
             # Fixing the RNG so that the same Monte Carlo samples are used across the batches
             rng = StableRNG(seed)
diff --git a/test/models/subsamplednormals.jl b/test/models/subsamplednormals.jl
new file mode 100644
index 000000000..86be0dc1a
--- /dev/null
+++ b/test/models/subsamplednormals.jl
@@ -0,0 +1,49 @@
+
+struct SubsampledNormals{D<:Normal,F<:Real}
+    dists::Vector{D}
+    likeadj::F
+end
+
+function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int)
+    μs = randn(rng, n_normals)
+    σs = ones(n_normals)
+    dists = Normal.(μs, σs)
+    return SubsampledNormals{eltype(dists),Float64}(dists, 1.0)
+end
+
+function LogDensityProblems.logdensity(m::SubsampledNormals, x)
+    (; likeadj, dists) = m
+    return likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists)
+end
+
+function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x)
+    return (
+        LogDensityProblems.logdensity(m, x),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x),
+    )
+end
+
+function LogDensityProblems.logdensity_gradient_and_hessian(m::SubsampledNormals, x)
+    return (
+        LogDensityProblems.logdensity(m, x),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x),
+        ForwardDiff.hessian(Base.Fix1(LogDensityProblems.logdensity, m), x),
+    )
+end
+
+function LogDensityProblems.capabilities(::Type{<:SubsampledNormals})
+    return LogDensityProblems.LogDensityOrder{2}()
+end
+
+function AdvancedVI.subsample(m::SubsampledNormals, idx)
+    n_data = length(m.dists)
+    return SubsampledNormals(m.dists[idx], n_data/length(idx))
+end
+
+function subsamplednormal(rng::Random.AbstractRNG, n_data::Int)
+    model = SubsampledNormals(rng, n_data)
+    n_dims = 1
+    μ_true = [mean([mean(dist) for dist in model.dists])]
+    L_true = Diagonal(ones(1))
+    return TestModel(model, μ_true, L_true, n_dims, 1, true)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 2dae5b31c..47bf0846b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -53,9 +53,9 @@ struct TestModel{M,L,S,SC}
 end
 include("models/normal.jl")
 include("models/normallognormal.jl")
+include("models/subsamplednormals.jl")
 
 if GROUP == "All" || GROUP == "GENERAL"
-    # Tests that do not need to check correct integration with AD backends
     include("general/optimize.jl")
     include("general/proximal_location_scale_entropy.jl")
     include("general/rules.jl")