diff --git a/Project.toml b/Project.toml index c7ddbd00..513ec4d8 100644 --- a/Project.toml +++ b/Project.toml @@ -4,12 +4,14 @@ authors = ["TuringLang"] version = "0.2.0" [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] +AbstractMCMC = "2" Distributions = "0.23, 0.24" Libtask = "0.5" StatsFuns = "0.9" diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index 4805a767..95d0974b 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -1,5 +1,6 @@ module AdvancedPS +import AbstractMCMC import Distributions import Libtask import Random @@ -7,5 +8,7 @@ import StatsFuns include("resampling.jl") include("container.jl") +include("smc.jl") +include("model.jl") end diff --git a/src/container.jl b/src/container.jl index ac20d477..7fdaf230 100644 --- a/src/container.jl +++ b/src/container.jl @@ -30,7 +30,7 @@ advance!(t::Trace) = Libtask.consume(t.ctask) # reset log probability reset_logprob!(t::Trace) = nothing -reset_model(f) = nothing +reset_model(f) = deepcopy(f) delete_retained!(f) = nothing # Task copying version of fork for Trace. @@ -59,7 +59,7 @@ function forkr(trace::Trace) # add backward reference newtrace = Trace(newf, ctask) addreference!(ctask.task, newtrace) - + return newtrace end @@ -96,6 +96,11 @@ Base.collect(pc::ParticleContainer) = pc.vals Base.length(pc::ParticleContainer) = length(pc.vals) Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i] +function Base.rand(rng::Random.AbstractRNG, pc::ParticleContainer) + index = randcat(rng, getweights(pc)) + return pc[index] +end + # registers a new x-particle in the container function Base.push!(pc::ParticleContainer, p::Particle) push!(pc.vals, p) @@ -319,7 +324,7 @@ function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler) logZ0 = logZ(pc) # Reweight the particles by including the first observation ``y₁``. - isdone = reweight!(rng, pc) + isdone = reweight!(pc) # Compute the normalizing constant ``Z₁`` after reweighting. logZ1 = logZ(pc) @@ -337,7 +342,7 @@ function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler) logZ0 = logZ(pc) # Reweight the particles by including the next observation ``yₜ``. - isdone = reweight!(rng, pc) + isdone = reweight!(pc) # Compute the normalizing constant ``Z₁`` after reweighting. logZ1 = logZ(pc) diff --git a/src/model.jl b/src/model.jl new file mode 100644 index 00000000..d2028389 --- /dev/null +++ b/src/model.jl @@ -0,0 +1,8 @@ +""" + observe(dist::Distribution, x) + +Observe sample `x` from distribution `dist` and yield its log-likelihood value. +""" +function observe(dist::Distributions.Distribution, x) + return Libtask.produce(Distributions.loglikelihood(dist, x)) +end diff --git a/src/smc.jl b/src/smc.jl index ac5a5c11..da5523ae 100644 --- a/src/smc.jl +++ b/src/smc.jl @@ -1,351 +1,129 @@ -### -### Particle Filtering and Particle MCMC Samplers. -### - -#### -#### Generic Sequential Monte Carlo sampler. -#### - -""" -$(TYPEDEF) - -Sequential Monte Carlo sampler. - -# Fields - -$(TYPEDFIELDS) -""" -struct SMC{space, R} <: ParticleInference +struct SMC{R} <: AbstractMCMC.AbstractSampler + nparticles::Int resampler::R end """ - SMC(space...) - SMC([resampler = ResampleWithESSThreshold(), space = ()]) - SMC([resampler = resample_systematic, ]threshold[, space = ()]) + SMC(n[, resampler = ResampleWithESSThreshold()]) + SMC(n, [resampler = resample_systematic, ]threshold) -Create a sequential Monte Carlo sampler of type [`SMC`](@ref) for the variables in `space`. +Create a sequential Monte Carlo (SMC) sampler with `n` particles. If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. """ -function SMC(resampler = Turing.Core.ResampleWithESSThreshold(), space::Tuple = ()) - return SMC{space, typeof(resampler)}(resampler) -end +SMC(nparticles::Int) = SMC(nparticles, ResampleWithESSThreshold()) # Convenient constructors with ESS threshold -function SMC(resampler, threshold::Real, space::Tuple = ()) - return SMC(Turing.Core.ResampleWithESSThreshold(resampler, threshold), space) +function SMC(nparticles::Int, resampler, threshold::Real) + return SMC(nparticles, ResampleWithESSThreshold(resampler, threshold)) end -SMC(threshold::Real, space::Tuple = ()) = SMC(resample_systematic, threshold, space) - -# If only the space is defined -SMC(space::Symbol...) = SMC(space) -SMC(space::Tuple) = SMC(Turing.Core.ResampleWithESSThreshold(), space) - -struct SMCTransition{T,F<:AbstractFloat} - "The parameters for any given sample." - θ::T - "The joint log probability of the sample (NOTE: does not work, always set to zero)." - lp::F - "The weight of the particle the sample was retrieved from." - weight::F -end - -function SMCTransition(vi::AbstractVarInfo, weight) - theta = tonamedtuple(vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. - lp = getlogp(vi) - - return SMCTransition(theta, lp, weight) -end - -metadata(t::SMCTransition) = (lp = t.lp, weight = t.weight) - -DynamicPPL.getlogp(t::SMCTransition) = t.lp +SMC(nparticles::Int, threshold::Real) = SMC(nparticles, resample_systematic, threshold) -struct SMCState{P,F<:AbstractFloat} - particles::P - particleindex::Int - # The logevidence after aggregating all samples together. - average_logevidence::F +struct SMCSample{P,W,L} + trajectories::P + weights::W + logevidence::L end -function getlogevidence(samples, sampler::Sampler{<:SMC}, state::SMCState) - return state.average_logevidence +function AbstractMCMC.sample(model::AbstractMCMC.AbstractModel, sampler::SMC; kwargs...) + return AbstractMCMC.sample(Random.GLOBAL_RNG, model, sampler; kwargs...) end function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - sampler::Sampler{<:SMC}, - N::Integer; - chain_type=MCMCChains.Chains, - resume_from=nothing, - progress=PROGRESS[], + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::SMC; kwargs... ) - if resume_from === nothing - return AbstractMCMC.mcmcsample(rng, model, sampler, N; - chain_type=chain_type, - progress=progress, - nparticles=N, - kwargs...) - else - return resume(resume_from, N; - chain_type=chain_type, progress=progress, nparticles=N, kwargs...) + if !isempty(kwargs) + @warn "keyword arguments $(keys(kwargs)) are not supported by `SMC`" end -end -function DynamicPPL.initialstep( - ::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:SMC}, - vi::AbstractVarInfo; - nparticles::Int, - kwargs... -) - # Reset the VarInfo. - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl) - resetlogp!(vi) - empty!(vi) - - # Create a new set of particles. - T = Trace{typeof(spl),typeof(vi),typeof(model)} - particles = ParticleContainer(T[Trace(model, spl, vi) for _ in 1:nparticles]) + # Create a set of particles. + particles = ParticleContainer([Trace(model) for _ in 1:sampler.nparticles]) # Perform particle sweep. - logevidence = sweep!(particles, spl.alg.resampler) - - # Extract the first particle and its weight. - particle = particles.vals[1] - weight = getweight(particles, 1) - - # Compute the first transition and the first state. - transition = SMCTransition(particle.vi, weight) - state = SMCState(particles, 2, logevidence) - - return transition, state -end + logevidence = sweep!(rng, particles, sampler.resampler) -function AbstractMCMC.step( - ::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:SMC}, - state::SMCState; - kwargs... -) - # Extract the index of the current particle. - index = state.particleindex - - # Extract the current particle and its weight. - particles = state.particles - particle = particles.vals[index] - weight = getweight(particles, index) - - # Compute the transition and the next state. - transition = SMCTransition(particle.vi, weight) - nextstate = SMCState(state.particles, index + 1, state.average_logevidence) - - return transition, nextstate + return SMCSample(collect(particles), getweights(particles), logevidence) end -#### -#### Particle Gibbs sampler. -#### - -""" -$(TYPEDEF) - -Particle Gibbs sampler. - -Note that this method is particle-based, and arrays of variables -must be stored in a [`TArray`](@ref) object. - -# Fields - -$(TYPEDFIELDS) -""" -struct PG{space,R} <: ParticleInference +struct PG{R} <: AbstractMCMC.AbstractSampler """Number of particles.""" nparticles::Int """Resampling algorithm.""" resampler::R end -isgibbscomponent(::PG) = true - """ - PG(n, space...) - PG(n, [resampler = ResampleWithESSThreshold(), space = ()]) - PG(n, [resampler = resample_systematic, ]threshold[, space = ()]) + PG(n, [resampler = ResampleWithESSThreshold()]) + PG(n, [resampler = resample_systematic, ]threshold) -Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles for the variables -in `space`. +Create a Particle Gibbs sampler with `n` particles. If the algorithm for the resampling step is not specified explicitly, systematic resampling is performed if the estimated effective sample size per particle drops below 0.5. """ -function PG( - nparticles::Int, - resampler = Turing.Core.ResampleWithESSThreshold(), - space::Tuple = (), -) - return PG{space, typeof(resampler)}(nparticles, resampler) -end +PG(nparticles::Int) = PG(nparticles, ResampleWithESSThreshold()) # Convenient constructors with ESS threshold -function PG(nparticles::Int, resampler, threshold::Real, space::Tuple = ()) - return PG(nparticles, Turing.Core.ResampleWithESSThreshold(resampler, threshold), space) -end -function PG(nparticles::Int, threshold::Real, space::Tuple = ()) - return PG(nparticles, resample_systematic, threshold, space) -end - -# If only the number of particles and the space is defined -PG(nparticles::Int, space::Symbol...) = PG(nparticles, space) -function PG(nparticles::Int, space::Tuple) - return PG(nparticles, Turing.Core.ResampleWithESSThreshold(), space) +function PG(nparticles::Int, resampler, threshold::Real) + return PG(nparticles, ResampleWithESSThreshold(resampler, threshold)) end +PG(nparticles::Int, threshold::Real) = PG(nparticles, resample_systematic, threshold) -const CSMC = PG # type alias of PG as Conditional SMC - -struct PGTransition{T,F<:AbstractFloat} - "The parameters for any given sample." - θ::T - "The joint log probability of the sample (NOTE: does not work, always set to zero)." - lp::F - "The log evidence of the sample." - logevidence::F +struct PGState{T} + trajectory::T end -function PGTransition(vi::AbstractVarInfo, logevidence) - theta = tonamedtuple(vi) - - # This is pretty useless since we reset the log probability continuously in the - # particle sweep. - lp = getlogp(vi) - - return PGTransition(theta, lp, logevidence) +struct PGSample{T,L} + trajectory::T + logevidence::L end -metadata(t::PGTransition) = (lp = t.lp, logevidence = t.logevidence) - -DynamicPPL.getlogp(t::PGTransition) = t.lp - -function getlogevidence(samples, sampler::Sampler{<:PG}, vi::AbstractVarInfo) - return mean(x.logevidence for x in samples) -end - -function DynamicPPL.initialstep( - rng::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:PG}, - vi::AbstractVarInfo; - kwargs... +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::PG; + kwargs..., ) - # Reset the VarInfo before new sweep - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl) - resetlogp!(vi) - - # Create a new set of particles - num_particles = spl.alg.nparticles - T = Trace{typeof(spl),typeof(vi),typeof(model)} - particles = ParticleContainer(T[Trace(model, spl, vi) for _ in 1:num_particles]) + # Create a new set of particles. + particles = ParticleContainer([Trace(model) for _ in 1:sampler.nparticles]) # Perform a particle sweep. - logevidence = sweep!(particles, spl.alg.resampler) + logevidence = sweep!(rng, particles, sampler.resampler) # Pick a particle to be retained. - Ws = getweights(particles) - indx = randcat(Ws) - reference = particles.vals[indx] + trajectory = rand(rng, particles) - # Compute the first transition. - _vi = reference.vi - transition = PGTransition(_vi, logevidence) - - return transition, _vi + return PGSample(trajectory, logevidence), PGState(trajectory) end function AbstractMCMC.step( - ::AbstractRNG, - model::AbstractModel, - spl::Sampler{<:PG}, - vi::AbstractVarInfo; + rng::Random.AbstractRNG, + model::AbstractMCMC.AbstractModel, + sampler::PG, + state::PGState; kwargs... ) - # Reset the VarInfo before new sweep. - reset_num_produce!(vi) - set_retained_vns_del_by_spl!(vi, spl) - resetlogp!(vi) - # Create a new set of particles. - num_particles = spl.alg.nparticles - T = Trace{typeof(spl),typeof(vi),typeof(model)} - x = Vector{T}(undef, num_particles) - @inbounds for i in 1:(num_particles - 1) - x[i] = Trace(model, spl, vi) + nparticles = sampler.nparticles + x = map(1:nparticles) do i + if i == nparticles + # Create reference trajectory. + forkr(state.trajectory) + else + Trace(model) + end end - # Create reference particle. - @inbounds x[num_particles] = forkr(Trace(model, spl, vi)) particles = ParticleContainer(x) # Perform a particle sweep. - logevidence = sweep!(particles, spl.alg.resampler) + logevidence = sweep!(rng, particles, sampler.resampler) # Pick a particle to be retained. - Ws = getweights(particles) - indx = randcat(Ws) - newreference = particles.vals[indx] + newtrajectory = rand(rng, particles) - # Compute the transition. - _vi = newreference.vi - transition = PGTransition(_vi, logevidence) - - return transition, _vi + return PGSample(newtrajectory, logevidence), PGState(newtrajectory) end - -function DynamicPPL.assume( - rng, - spl::Sampler{<:Union{PG,SMC}}, - dist::Distribution, - vn::VarName, - ::Any -) - vi = current_trace().vi - if inspace(vn, spl) - if ~haskey(vi, vn) - r = rand(rng, dist) - push!(vi, vn, r, dist, spl) - elseif is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = rand(rng, dist) - vi[vn] = vectorize(dist, r) - setgid!(vi, spl.selector, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - updategid!(vi, vn, spl) - r = vi[vn] - end - else # vn belongs to other sampler <=> conditionning on vn - if haskey(vi, vn) - r = vi[vn] - else - r = rand(rng, dist) - push!(vi, vn, r, dist, Selector(:invalid)) - end - lp = logpdf_with_trans(dist, r, istrans(vi, vn)) - acclogp!(vi, lp) - end - return r, 0 -end - -function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) - produce(logpdf(dist, value)) - return 0 -end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 3aacf35c..6d181a79 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,12 @@ [deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +AbstractMCMC = "2" +Distributions = "0.24" Libtask = "0.5" julia = "1.3" diff --git a/test/runtests.jl b/test/runtests.jl index 2ceb1147..7a3cd3c7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,6 @@ using AdvancedPS +using AbstractMCMC +using Distributions using Libtask using Random using Test @@ -6,4 +8,5 @@ using Test @testset "AdvancedPS.jl" begin @testset "Resampling tests" begin include("resampling.jl") end @testset "Container tests" begin include("container.jl") end + @testset "SMC and PG tests" begin include("smc.jl") end end diff --git a/test/smc.jl b/test/smc.jl index ec2bca09..2726b718 100644 --- a/test/smc.jl +++ b/test/smc.jl @@ -1,179 +1,153 @@ -using Turing, Random, Test -using Turing.Core: ResampleWithESSThreshold -using Turing.Inference: getspace, resample_systematic, resample_multinomial - -using Random - -dir = splitdir(splitdir(pathof(Turing))[1])[1] -include(dir*"/test/test_utils/AllUtils.jl") - -@testset "SMC" begin - @turing_testset "constructor" begin - s = SMC() - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === () +@testset "smc.jl" begin + @testset "SMC constructor" begin + sampler = AdvancedPS.SMC(10) + @test sampler.nparticles == 10 + @test sampler.resampler == AdvancedPS.ResampleWithESSThreshold() + + sampler = AdvancedPS.SMC(15, 0.6) + @test sampler.nparticles == 15 + @test sampler.resampler === + AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6) + + sampler = AdvancedPS.SMC(20, AdvancedPS.resample_multinomial, 0.6) + @test sampler.nparticles == 20 + @test sampler.resampler === + AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6) + + sampler = AdvancedPS.SMC(25, AdvancedPS.resample_systematic) + @test sampler.nparticles == 25 + @test sampler.resampler === AdvancedPS.resample_systematic + end - s = SMC(:x) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) + # Smoke tests + @testset "models" begin + mutable struct NormalModel <: AbstractMCMC.AbstractModel + a::Float64 + b::Float64 - s = SMC((:x,)) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) + NormalModel() = new() + end - s = SMC(:x, :y) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) + function (m::NormalModel)() + # First latent variable. + m.a = a = rand(Normal(4, 5)) - s = SMC((:x, :y)) - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) + # First observation. + AdvancedPS.observe(Normal(a, 2), 3) - s = SMC(0.6) - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === () + # Second latent variable. + m.b = b = rand(Normal(a, 1)) - s = SMC(0.6, (:x,)) - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === (:x,) + # Second observation. + AdvancedPS.observe(Normal(b, 2), 1.5) + end + sample(NormalModel(), AdvancedPS.SMC(100)) - s = SMC(resample_multinomial, 0.6) - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === () + # failing test + mutable struct FailSMCModel <: AbstractMCMC.AbstractModel + a::Float64 + b::Float64 - s = SMC(resample_multinomial, 0.6, (:x,)) - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === (:x,) + FailSMCModel() = new() + end - s = SMC(resample_systematic) - @test s.resampler === resample_systematic - @test getspace(s) === () + function (m::FailSMCModel)() + m.a = a = rand(Normal(4, 5)) + m.b = b = rand(Normal(a, 1)) + if a >= 4 + AdvancedPS.observe(Normal(b, 2), 1.5) + end + end - s = SMC(resample_systematic, (:x,)) - @test s.resampler === resample_systematic - @test getspace(s) === (:x,) + @test_throws ErrorException sample(FailSMCModel(), AdvancedPS.SMC(100)) end - @turing_testset "models" begin - @model function normal() - a ~ Normal(4,5) - 3 ~ Normal(a,2) - b ~ Normal(a,1) - 1.5 ~ Normal(b,2) - a, b - end + @testset "logevidence" begin + Random.seed!(100) - tested = sample(normal(), SMC(), 100); + mutable struct TestModel <: AbstractMCMC.AbstractModel + a::Float64 + x::Bool + b::Float64 + c::Float64 - # failing test - @model function fail_smc() - a ~ Normal(4,5) - 3 ~ Normal(a,2) - b ~ Normal(a,1) - if a >= 4.0 - 1.5 ~ Normal(b,2) - end - a, b + TestModel() = new() end - @test_throws ErrorException sample(fail_smc(), SMC(), 100) - end + function (m::TestModel)() + # First hidden variables. + m.a = rand(Normal(0, 1)) + m.x = x = rand(Bernoulli(1)) + m.b = rand(Gamma(2, 3)) - @turing_testset "logevidence" begin - Random.seed!(100) + # First observation. + AdvancedPS.observe(Bernoulli(x / 2), 1) - @model function test() - a ~ Normal(0, 1) - x ~ Bernoulli(1) - b ~ Gamma(2, 3) - 1 ~ Bernoulli(x / 2) - c ~ Beta() - 0 ~ Bernoulli(x / 2) - x + # Second hidden variable. + m.c = rand(Beta()) + + # Second observation. + AdvancedPS.observe(Bernoulli(x / 2), 0) end - chains_smc = sample(test(), SMC(), 100) + chains_smc = sample(TestModel(), AdvancedPS.SMC(100)) - @test all(isone, chains_smc[:x]) + @test all(isone(p.f.x) for p in chains_smc.trajectories) @test chains_smc.logevidence ≈ -2 * log(2) end -end -@testset "PG" begin - @turing_testset "constructor" begin - s = PG(10) - @test s.nparticles == 10 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === () - - s = PG(20, :x) - @test s.nparticles == 20 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) - - s = PG(30, (:x,)) - @test s.nparticles == 30 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x,) - - s = PG(40, :x, :y) - @test s.nparticles == 40 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) - - s = PG(50, (:x, :y)) - @test s.nparticles == 50 - @test s.resampler == ResampleWithESSThreshold() - @test getspace(s) === (:x, :y) - - s = PG(60, 0.6) - @test s.nparticles == 60 - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === () - - s = PG(70, 0.6, (:x,)) - @test s.nparticles == 70 - @test s.resampler === ResampleWithESSThreshold(resample_systematic, 0.6) - @test getspace(s) === (:x,) - - s = PG(80, resample_multinomial, 0.6) - @test s.nparticles == 80 - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === () - - s = PG(90, resample_multinomial, 0.6, (:x,)) - @test s.nparticles == 90 - @test s.resampler === ResampleWithESSThreshold(resample_multinomial, 0.6) - @test getspace(s) === (:x,) - - s = PG(100, resample_systematic) - @test s.nparticles == 100 - @test s.resampler === resample_systematic - @test getspace(s) === () - - s = PG(110, resample_systematic, (:x,)) - @test s.nparticles == 110 - @test s.resampler === resample_systematic - @test getspace(s) === (:x,) + @testset "PG constructor" begin + sampler = AdvancedPS.PG(10) + @test sampler.nparticles == 10 + @test sampler.resampler == AdvancedPS.ResampleWithESSThreshold() + + sampler = AdvancedPS.PG(60, 0.6) + @test sampler.nparticles == 60 + @test sampler.resampler === + AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_systematic, 0.6) + + sampler = AdvancedPS.PG(80, AdvancedPS.resample_multinomial, 0.6) + @test sampler.nparticles == 80 + @test sampler.resampler === + AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.6) + + sampler = AdvancedPS.PG(100, AdvancedPS.resample_systematic) + @test sampler.nparticles == 100 + @test sampler.resampler === AdvancedPS.resample_systematic end - @turing_testset "logevidence" begin + @testset "logevidence" begin Random.seed!(100) - @model function test() - a ~ Normal(0, 1) - x ~ Bernoulli(1) - b ~ Gamma(2, 3) - 1 ~ Bernoulli(x / 2) - c ~ Beta() - 0 ~ Bernoulli(x / 2) - x + mutable struct TestModel <: AbstractMCMC.AbstractModel + a::Float64 + x::Bool + b::Float64 + c::Float64 + + TestModel() = new() + end + + function (m::TestModel)() + # First hidden variables. + m.a = rand(Normal(0, 1)) + m.x = x = rand(Bernoulli(1)) + m.b = rand(Gamma(2, 3)) + + # First observation. + AdvancedPS.observe(Bernoulli(x / 2), 1) + + # Second hidden variable. + m.c = rand(Beta()) + + # Second observation. + AdvancedPS.observe(Bernoulli(x / 2), 0) end - chains_pg = sample(test(), PG(10), 100) + chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100) - @test all(isone, chains_pg[:x]) - @test chains_pg.logevidence ≈ -2 * log(2) atol = 0.01 + @test all(isone(p.trajectory.f.x) for p in chains_pg) + @test mean(x.logevidence for x in chains_pg) ≈ -2 * log(2) atol = 0.01 end end