
Commit 0af1045

Merge 169a014 into c79dce0
torfjelde committed Jun 10, 2021
2 parents c79dce0 + 169a014
Showing 8 changed files with 126 additions and 7 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.16.0"
version = "0.16.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -47,7 +47,7 @@ DocStringExtensions = "0.8"
DynamicPPL = "0.11.0"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5"
Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1"
MCMCChains = "4"
NamedArrays = "0.9"
Reexport = "0.2, 1"
8 changes: 5 additions & 3 deletions src/inference/ess.jl
@@ -112,7 +112,9 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
vns = _getvns(varinfo, sampler)
set_flag!(varinfo, vns[1][1], "del")
for vn in Iterators.flatten(values(vns))
set_flag!(varinfo, vn, "del")
end
p.model(rng, varinfo, sampler)
return varinfo[sampler]
end
@@ -155,6 +157,6 @@ function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS},
end
end

function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi)
function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi)
end
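
For orientation, a minimal standalone sketch (plain Julia, deliberately not using Turing internals) of the pattern the `Base.rand` hunk above switches to: mark every variable via `Iterators.flatten(values(vns))` instead of only the first entry. The names and container below are hypothetical stand-ins for `_getvns` and `set_flag!`:

# Hypothetical stand-in for the grouped variable names returned by `_getvns`.
vns = Dict(:m => ["m[1]", "m[2]"], :s => ["s"])

# Stand-in for `set_flag!(varinfo, vn, "del")`: record every marked name.
deleted = Set{String}()
for vn in Iterators.flatten(values(vns))
    push!(deleted, vn)
end

@assert length(deleted) == 3  # all variables marked, not just vns[:m][1]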
5 changes: 3 additions & 2 deletions src/variational/advi.jl
@@ -34,14 +34,15 @@ function Bijectors.bijector(
end

bs = Bijectors.bijector.(tuple(dists...))
rs = tuple(ranges...)

if sym2ranges
return (
Bijectors.Stacked(bs, ranges),
Bijectors.Stacked(bs, rs),
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
)
else
return Bijectors.Stacked(bs, ranges)
return Bijectors.Stacked(bs, rs)
end
end

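For context, a hedged sketch of why the ranges are converted with `tuple(ranges...)`: `Bijectors.Stacked` pairs each component bijector with an index range, and passing the ranges as a tuple keeps them in the same container type as the tuple of bijectors. The distributions below are illustrative, assuming the Bijectors.jl API available to this release:

using Bijectors, Distributions

bs = (bijector(Exponential()), bijector(Beta(2, 2)))  # per-component bijectors
rs = (1:1, 2:2)                                       # ranges as a tuple, mirroring `rs = tuple(ranges...)`
sb = Bijectors.Stacked(bs, rs)

sb([0.5, 0.3])  # each coordinate is transformed by its own bijector
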
2 changes: 2 additions & 0 deletions test/Project.toml
@@ -11,6 +11,7 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
@@ -40,6 +41,7 @@ DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.11.0"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12"
Libtask = "< 0.5.2"
MCMCChains = "4.0.4"
Memoization = "0.1.4"
NamedArrays = "0.9.4"
4 changes: 4 additions & 0 deletions test/inference/ess.jl
@@ -54,5 +54,9 @@
ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, alg, 6000)
check_MoGtest_default(chain, atol = 0.1)

# Different "equivalent" models.
Random.seed!(125)
check_gdemo_models(ESS(), 1_000)
end
end
9 changes: 9 additions & 0 deletions test/modes/ModeEstimation.jl
@@ -96,4 +96,13 @@
@test isapprox(mle1.values.array, mle2.values.array)
@test isapprox(map1.values.array, map2.values.array)
end

@testset "MAP on $(m.name)" for m in gdemo_models
result = optimize(m, MAP())
@test mean(result.values) ≈ 8.0 rtol=0.05
end
@testset "MLE on $(m.name)" for m in gdemo_models
result = optimize(m, MLE())
@test mean(result.values) ≈ 10.0 rtol=0.05
end
end
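
A hedged usage sketch of the mode-estimation API exercised above — `optimize(model, MAP())` and `optimize(model, MLE())` return a result whose `values` field holds the point estimate. The toy model below is a stand-in for the gdemo family (with `m ~ Normal()` and an observation of 10 at noise scale 0.5, the MAP lands near 8 and the MLE near 10, matching the tolerances in the tests):

using Turing, Optim

@model function gdemo_like(x)
    m ~ Normal()
    x ~ Normal(m, 0.5)
end

map_result = optimize(gdemo_like(10.0), MAP())
mle_result = optimize(gdemo_like(10.0), MLE())

mean(map_result.values), mean(mle_result.values)  # ≈ (8.0, 10.0)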
91 changes: 91 additions & 0 deletions test/test_utils/models.jl
@@ -51,3 +51,94 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0])

# Declare empty model to make the Sampler constructor work.
@model empty_model() = begin x = 1; end

# A collection of models for which the mean-of-means for the posterior should
# be the same.
@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
x ~ MvNormal(m, 0.5 * ones(length(x)))
end

@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
x ~ MvNormal(m, 0.5 * ones(length(x)))
end

@model function gdemo3(x = 10 * ones(2))
# Multivariate `assume` and `observe`
m ~ MvNormal(length(x), 1.0)
x ~ MvNormal(m, 0.5 * ones(length(x)))
end

@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m[i], 0.5)
end
end

# Using a vector of length 1 here so the posterior of `m` is the same
# as for the others.
@model function gdemo5(x = 10 * ones(1))
# `assume` and `dot_observe`
m ~ Normal()
x .~ Normal(m, 0.5)
end

@model function gdemo6()
# `assume` and literal `observe`
m ~ MvNormal(2, 1.0)
[10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
end

@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV}
# `dot_assume` and literal `observe` with indexing
m = TV(undef, 2)
m .~ Normal()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function gdemo8()
# `assume` and literal `dot_observe`
m ~ Normal()
[10.0, ] .~ Normal(m, 0.5)
end

@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV}
m = TV(undef, 2)
m .~ Normal()

return m
end

@model function gdemo9()
# Submodel prior
m = @submodel _prior_dot_assume()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function _likelihood_dot_observe(m, x)
x ~ MvNormal(m, 0.5 * ones(length(m)))
end

@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)
end

const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo6(), gdemo7(), gdemo8(), gdemo9(), gdemo10())
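
For readers unfamiliar with `@submodel` (used by `gdemo9` and `gdemo10` above): it splices one model's random variables into the calling model. A minimal sketch, assuming the DynamicPPL `@submodel` behaviour this test suite relies on; the model names here are illustrative:

using Turing, DynamicPPL

@model function inner()
    z ~ Normal()
    return z
end

@model function outer(y)
    z = @submodel inner()   # `z ~ Normal()` now contributes to `outer`'s joint
    y ~ Normal(z, 0.5)
end

sample(outer(10.0), NUTS(0.65), 500)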
10 changes: 10 additions & 0 deletions test/test_utils/numerical_tests.jl
@@ -64,3 +64,13 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0)
[1.0, 1.0, 2.0, 2.0, 1.0, 4.0],
atol=atol, rtol=rtol)
end

function check_gdemo_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...)
@testset "$(alg) on $(m.name)" for m in gdemo_models
# Log this so that if something goes wrong, we can identify the
# algorithm and model.
μ = mean(Array(sample(m, alg, nsamples, args...; kwargs...)))

@test μ ≈ 8.0 atol=atol rtol=rtol
end
end
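
A hypothetical invocation, mirroring the call added to test/inference/ess.jl above (assuming the test utilities and models are already loaded):

Random.seed!(125)
check_gdemo_models(ESS(), 1_000)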
