diff --git a/Project.toml b/Project.toml index ff85db03..5839fafd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,18 +1,20 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.2.4" +version = "0.3.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] AbstractMCMC = "2, 3" Distributions = "0.23, 0.24, 0.25" Libtask = "0.5.3" +Random123 = "1.3" StatsFuns = "0.9" julia = "1.3" diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index ce2ecf18..e1c7832d 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -5,8 +5,10 @@ using Distributions: Distributions using Libtask: Libtask using Random: Random using StatsFuns: StatsFuns +using Random123: Random123 include("resampling.jl") +include("rng.jl") include("container.jl") include("smc.jl") include("model.jl") diff --git a/src/container.jl b/src/container.jl index 504bc329..695892ea 100644 --- a/src/container.jl +++ b/src/container.jl @@ -1,31 +1,43 @@ -struct Trace{F} +struct Trace{F,U,N,V<:Random123.AbstractR123{U}} f::F ctask::Libtask.CTask + rng::TracedRNG{U,N,V} end const Particle = Trace -function Trace(f) +function Trace(f, rng::TracedRNG) ctask = let f = f Libtask.CTask() do - res = f() + res = f(rng) Libtask.produce(nothing) return res end end # add backward reference - newtrace = Trace(f, ctask) + newtrace = Trace(f, ctask, rng) addreference!(ctask.task, newtrace) return newtrace end -Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask)) +function Trace(f, ctask::Libtask.CTask) + return Trace(f, ctask, TracedRNG()) +end + +# Copy task +Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask), deepcopy(trace.rng)) # step to the next observe statement and # return the log probability of the transition (or nothing if done) -advance!(t::Trace) = Libtask.consume(t.ctask) +function advance!(t::Trace, isref::Bool) + isref ? load_state!(t.rng) : save_state!(t.rng) + inc_counter!(t.rng) + + # Move to next step + return Libtask.consume(t.ctask) +end # reset log probability reset_logprob!(t::Trace) = nothing @@ -48,16 +60,18 @@ end # Create new task and copy randomness function forkr(trace::Trace) newf = reset_model(trace.f) + Random123.set_counter!(trace.rng, 1) + ctask = let f = trace.ctask.task.code Libtask.CTask() do - res = f() + res = f()(trace.rng) Libtask.produce(nothing) return res end end # add backward reference - newtrace = Trace(newf, ctask) + newtrace = Trace(newf, ctask, trace.rng) addreference!(ctask.task, newtrace) return newtrace @@ -81,15 +95,21 @@ Data structure for particle filters - normalise!(pc::ParticleContainer) - consume(pc::ParticleContainer): return incremental likelihood """ -mutable struct ParticleContainer{T<:Particle} +mutable struct ParticleContainer{T<:Particle,U,N,V<:Random123.AbstractR123{U}} "Particles." vals::Vector{T} "Unnormalized logarithmic weights." logWs::Vector{Float64} + "Traced RNG to replay the resampling step" + rng::TracedRNG{U,N,V} end function ParticleContainer(particles::Vector{<:Particle}) - return ParticleContainer(particles, zeros(length(particles))) + return ParticleContainer(particles, zeros(length(particles)), TracedRNG()) +end + +function ParticleContainer(particles::Vector{<:Particle}, r::TracedRNG) + return ParticleContainer(particles, zeros(length(particles)), r) end Base.collect(pc::ParticleContainer) = pc.vals @@ -116,7 +136,10 @@ function Base.copy(pc::ParticleContainer) # copy weights logWs = copy(pc.logWs) - return ParticleContainer(vals, logWs) + # Copy rng and states + rng = copy(pc.rng) + + return ParticleContainer(vals, logWs, rng) end """ @@ -170,6 +193,22 @@ function effectiveSampleSize(pc::ParticleContainer) return inv(sum(abs2, Ws)) end +""" + update_keys!(pc::ParticleContainer) + +Create new unique keys for the particles in the ParticleContainer +""" +function update_keys!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing) + # Update keys to new particle ids + nparticles = length(pc) + n = ref === nothing ? nparticles : nparticles - 1 + for i in 1:n + pi = pc.vals[i] + k = split(pi.rng.rng.key) + Random.seed!(pi.rng, k[1]) + end +end + """ resample_propagate!(rng, pc::ParticleContainer[, randcat = resample_systematic, ref = nothing; weights = getweights(pc)]) @@ -213,11 +252,17 @@ function resample_propagate!( pi = particles[i] isref = pi === ref p = isref ? fork(pi, isref) : pi - children[j += 1] = p + nseeds = isref ? ni - 1 : ni + + seeds = split(p.rng.rng.key, nseeds) + !isref && Random.seed!(p.rng, seeds[1]) + children[j += 1] = p # fork additional children - for _ in 2:ni - children[j += 1] = fork(p, isref) + for k in 2:ni + part = fork(p, isref) + Random.seed!(part.rng, seeds[k]) + children[j += 1] = part end end end @@ -247,6 +292,8 @@ function resample_propagate!( if ess ≤ resampler.threshold * length(pc) resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights) + else + update_keys!(pc, ref) end return pc @@ -258,7 +305,7 @@ end Check if the final time step is reached, and otherwise reweight the particles by considering the next observation. """ -function reweight!(pc::ParticleContainer) +function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing) n = length(pc) particles = collect(pc) @@ -270,7 +317,8 @@ function reweight!(pc::ParticleContainer) # the execution of the model is finished. # Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and # ``θᵢ`` are variables of other samplers. - score = advance!(p) + isref = p === ref + score = advance!(p, isref) if score === nothing numdone += 1 @@ -321,7 +369,6 @@ function sweep!( ref::Union{Particle,Nothing}=nothing, ) # Initial step: - # Resample and propagate particles. resample_propagate!(rng, pc, resampler, ref) @@ -333,7 +380,7 @@ function sweep!( logZ0 = logZ(pc) # Reweight the particles by including the first observation ``y₁``. - isdone = reweight!(pc) + isdone = reweight!(pc, ref) # Compute the normalizing constant ``Z₁`` after reweighting. logZ1 = logZ(pc) @@ -351,7 +398,7 @@ function sweep!( logZ0 = logZ(pc) # Reweight the particles by including the next observation ``yₜ``. - isdone = reweight!(pc) + isdone = reweight!(pc, ref) # Compute the normalizing constant ``Z₁`` after reweighting. logZ1 = logZ(pc) diff --git a/src/rng.jl b/src/rng.jl new file mode 100644 index 00000000..a55f8cc9 --- /dev/null +++ b/src/rng.jl @@ -0,0 +1,85 @@ +# Default RNG type for when nothing is specified +const _BASE_RNG = Random123.Philox2x + +""" + TracedRNG{R,N,T} + +Wrapped random number generator from Random123 to keep track of random streams during model evaluation +""" +mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG + "Model step counter" + count::Int + "Inner RNG" + rng::T + "Array of keys" + keys::Array{R,N} +end + +""" + TracedRNG(r::Random123.AbstractR123=AdvancedPS._BASE_RNG()) +Create a `TracedRNG` with `r` as the inner RNG. +""" +function TracedRNG(r::Random123.AbstractR123=_BASE_RNG()) + Random123.set_counter!(r, 0) + return TracedRNG(1, r, typeof(r.key)[]) +end + +# Connect to the Random API +Random.rng_native_52(rng::TracedRNG) = Random.rng_native_52(rng.rng) +Base.rand(rng::TracedRNG, ::Type{T}) where {T} = Base.rand(rng.rng, T) + +""" + split(key::Integer, n::Integer=1) + +Split `key` into `n` new keys +""" +function split(key::Integer, n::Integer=1) + T = typeof(key) # Make sure the type of `key` is consistent on W32 and W64 systems. + return T[hash(key, i) for i in UInt(1):UInt(n)] +end + +""" + load_state!(r::TracedRNG) + +Load state from current model iteration. Random streams are now replayed +""" +function load_state!(rng::TracedRNG) + key = rng.keys[rng.count] + Random.seed!(rng.rng, key) + return Random123.set_counter!(rng.rng, 0) +end + +""" + update_rng!(rng::TracedRNG) + +Set key and counter of inner rng in `rng` to `key` and the running model step to 0 +""" +function Random.seed!(rng::TracedRNG, key) + Random.seed!(rng.rng, key) + return Random123.set_counter!(rng.rng, 0) +end + +""" + save_state!(r::TracedRNG) + +Add current key of the inner rng in `r` to `keys`. +""" +function save_state!(r::TracedRNG) + return push!(r.keys, r.rng.key) +end + +Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys)) + +""" + set_counter!(r::TracedRNG, n::Integer) + +Set the counter of the inner rng in `r`, used to keep track of the current model step +""" +Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n + +""" + inc_counter!(r::TracedRNG, n::Integer=1) + +Increase the model step counter by `n` +""" +inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n diff --git a/src/smc.jl b/src/smc.jl index 89d465c8..433780ba 100644 --- a/src/smc.jl +++ b/src/smc.jl @@ -38,7 +38,9 @@ function AbstractMCMC.sample( end # Create a set of particles. - particles = ParticleContainer([Trace(model) for _ in 1:(sampler.nparticles)]) + particles = ParticleContainer( + [Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG() + ) # Perform particle sweep. logevidence = sweep!(rng, particles, sampler.resampler) @@ -83,7 +85,9 @@ function AbstractMCMC.step( rng::Random.AbstractRNG, model::AbstractMCMC.AbstractModel, sampler::PG; kwargs... ) # Create a new set of particles. - particles = ParticleContainer([Trace(model) for _ in 1:(sampler.nparticles)]) + particles = ParticleContainer( + [Trace(model, TracedRNG()) for _ in 1:(sampler.nparticles)], TracedRNG() + ) # Perform a particle sweep. logevidence = sweep!(rng, particles, sampler.resampler) @@ -108,10 +112,10 @@ function AbstractMCMC.step( # Create reference trajectory. forkr(state.trajectory) else - Trace(model) + Trace(model, TracedRNG()) end end - particles = ParticleContainer(x) + particles = ParticleContainer(x, TracedRNG()) # Perform a particle sweep. logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles]) diff --git a/test/Project.toml b/test/Project.toml index a03a636d..d647de7f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] @@ -10,3 +11,4 @@ AbstractMCMC = "2, 3" Distributions = "0.24, 0.25" Libtask = "0.5" julia = "1.3" +Random123 = "1.3" \ No newline at end of file diff --git a/test/container.jl b/test/container.jl index 69c524e5..1dbc33df 100644 --- a/test/container.jl +++ b/test/container.jl @@ -11,7 +11,7 @@ # Create a resumable function that always returns the same log probability. function fpc(logp) f = let logp = logp - () -> begin + (rng) -> begin while true produce(logp) end @@ -22,7 +22,7 @@ # Create particle container. logps = [0.0, -1.0, -2.0] - particles = [AdvancedPS.Trace(fpc(logp)) for logp in logps] + particles = [AdvancedPS.Trace(fpc(logp), AdvancedPS.TracedRNG()) for logp in logps] pc = AdvancedPS.ParticleContainer(particles) # Initial state. @@ -52,7 +52,9 @@ @test AdvancedPS.logZ(pc) ≈ log(sum(exp, 2 .* logps)) # Resample and propagate particles with reference particle - particles_ref = [AdvancedPS.Trace(fpc(logp)) for logp in logps] + particles_ref = [ + AdvancedPS.Trace(fpc(logp), AdvancedPS.TracedRNG()) for logp in logps + ] pc_ref = AdvancedPS.ParticleContainer(particles_ref) AdvancedPS.resample_propagate!( Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, particles_ref[end] @@ -95,7 +97,7 @@ @testset "trace" begin n = Ref(0) - function f2() + function f2(rng) t = TArray(Int, 1) t[1] = 0 while true @@ -107,7 +109,7 @@ end # Test task copy version of trace - tr = AdvancedPS.Trace(f2) + tr = AdvancedPS.Trace(f2, AdvancedPS.TracedRNG()) consume(tr.ctask) consume(tr.ctask) diff --git a/test/rng.jl b/test/rng.jl new file mode 100644 index 00000000..619e9fde --- /dev/null +++ b/test/rng.jl @@ -0,0 +1,24 @@ +@testset "rng.jl" begin + @testset "sample distribution" begin + rng = AdvancedPS.TracedRNG() + vns = rand(rng, Distributions.Normal()) + AdvancedPS.save_state!(rng) + + rand(rng, Distributions.Normal()) + + AdvancedPS.load_state!(rng) + new_vns = rand(rng, Distributions.Normal()) + @test new_vns ≈ vns + end + + @testset "split" begin + rng = AdvancedPS.TracedRNG() + key = rng.rng.key + new_key, = AdvancedPS.split(key, 1) + + @test key ≠ new_key + + Random.seed!(rng, new_key) + @test rng.rng.key === new_key + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 5497a17e..32adc400 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,4 +15,7 @@ using Test @testset "SMC and PG tests" begin include("smc.jl") end + @testset "RNG tests" begin + include("rng.jl") + end end diff --git a/test/smc.jl b/test/smc.jl index 1f98eb37..4de676b4 100644 --- a/test/smc.jl +++ b/test/smc.jl @@ -28,19 +28,20 @@ NormalModel() = new() end - function (m::NormalModel)() + function (m::NormalModel)(rng::Random.AbstractRNG) # First latent variable. - m.a = a = rand(Normal(4, 5)) + m.a = a = rand(rng, Normal(4, 5)) # First observation. AdvancedPS.observe(Normal(a, 2), 3) # Second latent variable. - m.b = b = rand(Normal(a, 1)) + m.b = b = rand(rng, Normal(a, 1)) # Second observation. return AdvancedPS.observe(Normal(b, 2), 1.5) end + sample(NormalModel(), AdvancedPS.SMC(100)) # failing test @@ -51,9 +52,9 @@ FailSMCModel() = new() end - function (m::FailSMCModel)() - m.a = a = rand(Normal(4, 5)) - m.b = b = rand(Normal(a, 1)) + function (m::FailSMCModel)(rng::Random.AbstractRNG) + m.a = a = rand(rng, Normal(4, 5)) + m.b = b = rand(rng, Normal(a, 1)) if a >= 4 AdvancedPS.observe(Normal(b, 2), 1.5) end @@ -74,17 +75,17 @@ TestModel() = new() end - function (m::TestModel)() + function (m::TestModel)(rng::Random.AbstractRNG) # First hidden variables. - m.a = rand(Normal(0, 1)) - m.x = x = rand(Bernoulli(1)) - m.b = rand(Gamma(2, 3)) + m.a = rand(rng, Normal(0, 1)) + m.x = x = rand(rng, Bernoulli(1)) + m.b = rand(rng, Gamma(2, 3)) # First observation. AdvancedPS.observe(Bernoulli(x / 2), 1) # Second hidden variable. - m.c = rand(Beta()) + m.c = rand(rng, Beta()) # Second observation. return AdvancedPS.observe(Bernoulli(x / 2), 0) @@ -128,17 +129,17 @@ TestModel() = new() end - function (m::TestModel)() + function (m::TestModel)(rng::Random.AbstractRNG) # First hidden variables. - m.a = rand(Normal(0, 1)) - m.x = x = rand(Bernoulli(1)) - m.b = rand(Gamma(2, 3)) + m.a = rand(rng, Normal(0, 1)) + m.x = x = rand(rng, Bernoulli(1)) + m.b = rand(rng, Gamma(2, 3)) # First observation. AdvancedPS.observe(Bernoulli(x / 2), 1) # Second hidden variable. - m.c = rand(Beta()) + m.c = rand(rng, Beta()) # Second observation. return AdvancedPS.observe(Bernoulli(x / 2), 0) @@ -149,6 +150,34 @@ @test all(isone(p.trajectory.f.x) for p in chains_pg) @test mean(x.logevidence for x in chains_pg) ≈ -2 * log(2) atol = 0.01 end + + @testset "Replay reference" begin + mutable struct Model <: AbstractMCMC.AbstractModel + a::Float64 + b::Float64 + + Model() = new() + end + + function (m::Model)(rng) + m.a = rand(rng, Normal()) + AdvancedPS.observe(Normal(), m.a) + + m.b = rand(rng, Normal()) + return AdvancedPS.observe(Normal(), m.b) + end + + pg = AdvancedPS.PG(1) + first, second = sample(Model(), pg, 2) + + first_model = first.trajectory.f + second_model = second.trajectory.f + + # Single Particle - must be replaying + @test first_model.a ≈ second_model.a + @test first_model.b ≈ second_model.b + @test first.logevidence ≈ second.logevidence + end end # @testset "pmmh.jl" begin