62 changes: 15 additions & 47 deletions test/algorithms/subsampledobj.jl
@@ -1,44 +1,12 @@

-struct SubsampledNormals{D<:Normal,F<:Real}
-    dists::Vector{D}
-    likeadj::F
-end
-
-function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int)
-    μs = randn(rng, n_normals)
-    σs = ones(n_normals)
-    dists = Normal.(μs, σs)
-    SubsampledNormals{eltype(dists),Float64}(dists, 1.0)
-end
-
-function LogDensityProblems.logdensity(m::SubsampledNormals, x)
-    (; likeadj, dists) = m
-    likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists)
-end
-
-function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x)
-    return (
-        LogDensityProblems.logdensity(m, x),
-        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x),
-    )
-end
-
-function LogDensityProblems.capabilities(::Type{<:SubsampledNormals})
-    return LogDensityProblems.LogDensityOrder{1}()
-end
-
-function AdvancedVI.subsample(m::SubsampledNormals, idx)
-    n_data = length(m.dists)
-    SubsampledNormals(m.dists[idx], n_data/length(idx))
-end
-
 @testset "SubsampledObjective" begin
     seed = (0x38bef07cf9cc549d)
     n_data = 8
-    prob = SubsampledNormals(Random.default_rng(), n_data)
-
-    μ0 = [mean([mean(dist) for dist in prob.dists])]
-    q0 = MeanFieldGaussian(μ0, Diagonal(ones(1)))
+    modelstats = subsamplednormal(Random.default_rng(), n_data)
+    (; model, n_dims, μ_true, L_true) = modelstats
+
+    q0 = MeanFieldGaussian(μ_true, Diagonal(diag(L_true)))
     full_obj = RepGradELBO(10)
 
     @testset "algorithm constructors" begin
@@ -47,17 +15,17 @@ end
             alg = KLMinRepGradDescent(
                 AD; n_samples=10, subsampling=sub, operator=ClipScale()
             )
-            _, info, _ = optimize(alg, 10, prob, q0; show_progress=false)
+            _, info, _ = optimize(alg, 10, model, q0; show_progress=false)
             @test isfinite(last(info).elbo)
 
             alg = KLMinRepGradProxDescent(AD; n_samples=10, subsampling=sub)
-            _, info, _ = optimize(alg, 10, prob, q0; show_progress=false)
+            _, info, _ = optimize(alg, 10, model, q0; show_progress=false)
             @test isfinite(last(info).elbo)
 
             alg = KLMinScoreGradDescent(
                 AD; n_samples=100, subsampling=sub, operator=ClipScale()
             )
-            _, info, _ = optimize(alg, 10, prob, q0; show_progress=false)
+            _, info, _ = optimize(alg, 10, model, q0; show_progress=false)
             @test isfinite(last(info).elbo)
         end
     end
@@ -69,25 +37,25 @@ end
         sub_obj = alg.objective
 
         rng = StableRNG(seed)
-        q_avg, _, _ = optimize(rng, alg, T, prob, q0; show_progress=false)
+        q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=false)
 
         rng = StableRNG(seed)
-        q_avg_ref, _, _ = optimize(rng, alg, T, prob, q0; show_progress=false)
+        q_avg_ref, _, _ = optimize(rng, alg, T, model, q0; show_progress=false)
         @test q_avg == q_avg_ref
 
         rng = StableRNG(seed)
-        sub_objval_ref = estimate_objective(rng, sub_obj, q0, prob)
+        sub_objval_ref = estimate_objective(rng, sub_obj, q0, model)
 
         rng = StableRNG(seed)
-        sub_objval = estimate_objective(rng, sub_obj, q0, prob)
+        sub_objval = estimate_objective(rng, sub_obj, q0, model)
         @test sub_objval == sub_objval_ref
     end
 
     @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4]
         sub = ReshufflingBatchSubsampling(1:n_data, batchsize)
         sub_obj′ = SubsampledObjective(full_obj, sub)
-        full_objval = estimate_objective(full_obj, q0, prob; n_samples=10^8)
-        sub_objval = estimate_objective(sub_obj′, q0, prob; n_samples=10^8)
+        full_objval = estimate_objective(full_obj, q0, model; n_samples=10^8)
+        sub_objval = estimate_objective(sub_obj′, q0, model; n_samples=10^8)
         @test full_objval ≈ sub_objval rtol=0.1
     end

@@ -100,15 +68,15 @@ end

         # Estimate using full batch
         rng = StableRNG(seed)
-        full_state = AdvancedVI.init(rng, full_obj, AD, q0, prob, params, restructure)
+        full_state = AdvancedVI.init(rng, full_obj, AD, q0, model, params, restructure)
         AdvancedVI.estimate_gradient!(
             rng, full_obj, AD, out, full_state, params, restructure
         )
         grad_ref = DiffResults.gradient(out)
 
         # Estimate the full batch gradient by averaging the minibatch gradients
         rng = StableRNG(seed)
-        sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, prob, params, restructure)
+        sub_state = AdvancedVI.init(rng, sub_obj, AD, q0, model, params, restructure)
         grad = mean(1:length(sub_obj.subsampling)) do _
             # Fixing the RNG so that the same Monte Carlo samples are used across the batches
             rng = StableRNG(seed)
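The last hunk above checks that averaging likeadj-rescaled minibatch gradients over one epoch reproduces the full-batch gradient. The identity behind it is easy to verify in isolation: with likeadj = n_data/batchsize, the minibatch log-densities average exactly to the full log-density, and any gradient taken with the same Monte Carlo samples averages the same way. A minimal standalone sketch (not part of the diff; it assumes only Distributions, ForwardDiff, and Statistics, and the logdens helper is a throwaway stand-in for the test model):

using Distributions, ForwardDiff, Statistics

μs = randn(8)  # one unit-variance normal per datum, as in SubsampledNormals
logdens(idx, adj, x) = adj * sum(logpdf(Normal(μs[i], 1.0), only(x)) for i in idx)

# Full-batch gradient vs. the average of rescaled minibatch gradients
full_grad = ForwardDiff.gradient(x -> logdens(1:8, 1.0, x), [0.3])
batch_grads = [ForwardDiff.gradient(x -> logdens(idx, 8/length(idx), x), [0.3]) for idx in (1:4, 5:8)]
full_grad ≈ mean(batch_grads)  # true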
49 changes: 49 additions & 0 deletions test/models/subsamplednormals.jl
@@ -0,0 +1,49 @@
+
+struct SubsampledNormals{D<:Normal,F<:Real}
+    dists::Vector{D}
+    likeadj::F
+end
+
+function SubsampledNormals(rng::Random.AbstractRNG, n_normals::Int)
+    μs = randn(rng, n_normals)
+    σs = ones(n_normals)
+    dists = Normal.(μs, σs)
+    return SubsampledNormals{eltype(dists),Float64}(dists, 1.0)
+end
+
+function LogDensityProblems.logdensity(m::SubsampledNormals, x)
+    (; likeadj, dists) = m
+    return likeadj*mapreduce(Base.Fix2(logpdf, only(x)), +, dists)
+end
+
+function LogDensityProblems.logdensity_and_gradient(m::SubsampledNormals, x)
+    return (
+        LogDensityProblems.logdensity(m, x),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x),
+    )
+end
+
+function LogDensityProblems.logdensity_gradient_and_hessian(m::SubsampledNormals, x)
+    return (
+        LogDensityProblems.logdensity(m, x),
+        ForwardDiff.gradient(Base.Fix1(LogDensityProblems.logdensity, m), x),
+        ForwardDiff.hessian(Base.Fix1(LogDensityProblems.logdensity, m), x),
+    )
+end
+
+function LogDensityProblems.capabilities(::Type{<:SubsampledNormals})
+    return LogDensityProblems.LogDensityOrder{2}()
+end
+
+function AdvancedVI.subsample(m::SubsampledNormals, idx)
+    n_data = length(m.dists)
+    return SubsampledNormals(m.dists[idx], n_data/length(idx))
+end
+
+function subsamplednormal(rng::Random.AbstractRNG, n_data::Int)
+    model = SubsampledNormals(rng, n_data)
+    n_dims = 1
+    μ_true = [mean([mean(dist) for dist in model.dists])]
+    L_true = Diagonal(ones(1))
+    return TestModel(model, μ_true, L_true, n_dims, 1, true)
+end
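The likeadj field maintained by subsample is what makes a minibatch an unbiased stand-in for the full data set: the retained terms are scaled by n_data/length(idx), so the subsampled log-densities average back to the full log-density over disjoint batches. A quick REPL check (assuming this file's definitions are loaded along with Distributions, LogDensityProblems, AdvancedVI, Random, and Statistics):

m = SubsampledNormals(Random.default_rng(), 8)
x = [0.3]
full = LogDensityProblems.logdensity(m, x)
halves = (AdvancedVI.subsample(m, 1:4), AdvancedVI.subsample(m, 5:8))
full ≈ mean(LogDensityProblems.logdensity(s, x) for s in halves)  # true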
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -53,9 +53,9 @@ struct TestModel{M,L,S,SC}
 end
 include("models/normal.jl")
 include("models/normallognormal.jl")
+include("models/subsamplednormals.jl")
 
 if GROUP == "All" || GROUP == "GENERAL"
     # Tests that do not need to check correct integration with AD backends
     include("general/optimize.jl")
     include("general/proximal_location_scale_entropy.jl")
     include("general/rules.jl")
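The GROUP branch above gates which test files run. The diff does not show where GROUP is defined, but the conventional pattern in Julia test suites (an assumption here, not confirmed by this diff) is to read it from the environment near the top of runtests.jl:

const GROUP = get(ENV, "GROUP", "All")  # assumed; not shown in this diff

so that running the suite with GROUP=GENERAL exercises only the tests that do not depend on a specific AD backend.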