Skip to content

Commit

Permalink
fix bug for bijector with 1 MC sample with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Dec 22, 2023
1 parent 9ebfc3f commit 05dbb51
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 40 deletions.
48 changes: 29 additions & 19 deletions ext/AdvancedVIBijectorsExt.jl
Expand Up @@ -11,35 +11,45 @@ else
using ..Random
end

function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
q_stop ::Bijectors.TransformedDistribution,
n_samples::Int,
ent_est ::AdvancedVI.AbstractEntropyEstimator
)
transform = q.transform
q_base = q.dist
q_base_stop = q_stop.dist
base_samples = rand(rng, q_base, n_samples)
it = AdvancedVI.eachsample(base_samples)
sample_init = first(it)
function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
unconst_iter = AdvancedVI.eachsample(unconst_samples)
unconst_init = first(unconst_iter)

samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)

samples_and_logjac = mapreduce(
AdvancedVI.catsamples_and_acc,
Iterators.drop(it, 1);
init=with_logabsdet_jacobian(transform, sample_init)
Iterators.drop(unconst_iter, 1);
init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
) do sample
with_logabsdet_jacobian(transform, sample)
end
samples = first(samples_and_logjac)
logjac = last(samples_and_logjac)
logjac = last(samples_and_logjac)/n_samples
samples, logjac
end

entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
ent_est, base_samples, q_base, q_base_stop
function AdvancedVI.reparam_with_entropy(
rng ::Random.AbstractRNG,
q ::Bijectors.TransformedDistribution,
q_stop ::Bijectors.TransformedDistribution,
n_samples::Int,
ent_est ::AdvancedVI.AbstractEntropyEstimator
)
transform = q.transform
q_unconst = q.dist
q_unconst_stop = q_stop.dist

# Draw samples and compute entropy of the uncontrained distribution
unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
rng, q_unconst, q_unconst_stop, n_samples, ent_est
)

entropy = entropy_base + logjac/n_samples
# Apply bijector to samples while estimating its jacobian
samples, logjac = transform_samples_with_jacobian(
unconst_samples, transform, n_samples
)
entropy = unconst_entropy + logjac
samples, entropy
end
end
4 changes: 4 additions & 0 deletions src/utils.jl
Expand Up @@ -34,3 +34,7 @@ function catsamples_and_acc(
return (x, ∑y)
end

function samples_expand_dim(x::AbstractVector)
reshape(x, (:,1))
end

11 changes: 6 additions & 5 deletions test/inference/repgradelbo_distributionsad.jl
Expand Up @@ -9,9 +9,10 @@ using Test
(modelname, modelconstr) Dict(
:Normal=> normal_meanfield,
),
(objname, objective) Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
:ForwarDiff => AutoForwardDiff(),
Expand All @@ -33,7 +34,7 @@ using Test
q0 = TuringDiagMvNormal(μ0, diag(L0))

@testset "convergence" begin
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
Expand All @@ -45,7 +46,7 @@ using Test
L = sqrt(cov(q))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test Δλ Δλ0/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand Down
17 changes: 9 additions & 8 deletions test/inference/repgradelbo_locationscale.jl
Expand Up @@ -5,16 +5,17 @@ using Test

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64, Float32],
(modelname, modelconstr) Dict(
realtype in [Float64, Float32],
(modelname, modelconstr) in Dict(
:Normal=> normal_meanfield,
:Normal=> normal_fullrank,
),
(objname, objective) Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
(adbackname, adbackend) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
Expand All @@ -37,7 +38,7 @@ using Test
end

@testset "convergence" begin
Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
Expand All @@ -49,7 +50,7 @@ using Test
L = q.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test Δλ Δλ0/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand Down
17 changes: 9 additions & 8 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Expand Up @@ -5,15 +5,16 @@ using Test

@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64, Float32],
(modelname, modelconstr) Dict(
realtype in [Float64, Float32],
(modelname, modelconstr) in Dict(
:NormalLogNormalMeanField => normallognormal_meanfield,
),
(objname, objective) Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
(adbackname, adbackend) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
Expand Down Expand Up @@ -42,7 +43,7 @@ using Test
q0_z = Bijectors.transformed(q0_η, b⁻¹)

@testset "convergence" begin
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
Expand All @@ -54,7 +55,7 @@ using Test
L = q.dist.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test Δλ Δλ0/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand Down

0 comments on commit 05dbb51

Please sign in to comment.