From eb0bfcf4abf1ae2c1a760d34e3d024376cb14a32 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 3 Dec 2020 00:15:29 +0100 Subject: [PATCH 1/2] Add Trace without Turing --- Project.toml | 4 ++ src/AdvancedPS.jl | 4 ++ src/container.jl | 163 +++++++++++++++++++++++----------------------- src/resampling.jl | 22 ++++--- 4 files changed, 101 insertions(+), 92 deletions(-) diff --git a/Project.toml b/Project.toml index adca95de..81c08f78 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,11 @@ version = "0.1.0" [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] Distributions = "0.23, 0.24" +Libtask = "0.5" +StatsFuns = "0.9" julia = "1.3" diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index 8fcbba3d..cb098ec6 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -1,6 +1,10 @@ module AdvancedPS import Distributions +import Libtask +import StatsFuns include("resampling.jl") +include("container.jl") + end diff --git a/src/container.jl b/src/container.jl index 4f4c9822..2a4d13d5 100644 --- a/src/container.jl +++ b/src/container.jl @@ -1,81 +1,79 @@ -mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model} - model::Tmodel - spl::Tspl - vi::Tvi - ctask::CTask - - function Trace{SampleFromPrior}(model::Model, spl::AbstractSampler, vi::AbstractVarInfo) - return new{SampleFromPrior,typeof(vi),typeof(model)}(model, SampleFromPrior(), vi) - end - function Trace{S}(model::Model, spl::S, vi::AbstractVarInfo) where S<:Sampler - return new{S,typeof(vi),typeof(model)}(model, spl, vi) - end +struct Trace{F} + f::F + ctask::Libtask.CTask end -function Base.copy(trace::Trace) - vi = deepcopy(trace.vi) - res = Trace{typeof(trace.spl)}(trace.model, trace.spl, vi) - res.ctask = copy(trace.ctask) - return res -end +const Particle = Trace -# NOTE: this function is called by `forkr` -function Trace(f, m::Model, spl::AbstractSampler, vi::AbstractVarInfo) - res = Trace{typeof(spl)}(m, spl, deepcopy(vi)) - ctask = CTask() do - res = f() - produce(nothing) - return res - end - task = ctask.task - if task.storage === nothing - task.storage = IdDict() +function Trace(f) + ctask = let f=f + Libtask.CTask() do + res = f() + Libtask.produce(nothing) + return res + end end - task.storage[:turing_trace] = res # create a backward reference in task_local_storage - res.ctask = ctask - return res -end -function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo) - res = Trace{typeof(spl)}(m, spl, deepcopy(vi)) - reset_num_produce!(res.vi) - ctask = CTask() do - res = m(vi, spl) - produce(nothing) - return res - end - task = ctask.task - if task.storage === nothing - task.storage = IdDict() - end - task.storage[:turing_trace] = res # create a backward reference in task_local_storage - res.ctask = ctask - return res + # add backward reference + newtrace = Trace(f, ctask) + addreference!(ctask.task, newtrace) + + return newtrace end -# step to the next observe statement, return log likelihood -Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.ctask)) +Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask)) + +# step to the next observe statement and +# return the log probability of the transition (or nothing if done) +advance!(t::Trace) = Libtask.consume(t.ctask) + +# reset log probability +reset_logprob!(t::Trace) = nothing + +reset_model(f) = nothing +delete_retained!(f) = nothing # Task copying version of fork for Trace. 
-function fork(trace :: Trace, is_ref :: Bool = false) +function fork(trace::Trace, isref::Bool = false) newtrace = copy(trace) - is_ref && set_retained_vns_del_by_spl!(newtrace.vi, newtrace.spl) - newtrace.ctask.task.storage[:turing_trace] = newtrace + isref && delete_retained!(newtrace.f) + + # add backward reference + addreference!(newtrace.ctask.task, newtrace) + return newtrace end # PG requires keeping all randomness for the reference particle # Create new task and copy randomness function forkr(trace::Trace) - newtrace = Trace(trace.ctask.task.code, trace.model, trace.spl, deepcopy(trace.vi)) - newtrace.spl = trace.spl - reset_num_produce!(newtrace.vi) + newf = reset_model(trace.f) + ctask = let f=trace.ctask.task.code + Libtask.CTask() do + res = f() + Libtask.produce(nothing) + return res + end + end + + # add backward reference + newtrace = Trace(newf, ctask) + addreference!(ctask.task, newtrace) + return newtrace end -current_trace() = current_task().storage[:turing_trace] +# create a backward reference in task_local_storage +function addreference!(task::Task, trace::Trace) + if task.storage === nothing + task.storage = IdDict() + end + task.storage[:__trace] = trace -const Particle = Trace + return task +end + +current_trace() = current_task().storage[:__trace] """ Data structure for particle filters @@ -141,7 +139,7 @@ end Compute the normalized weights of the particles. """ -getweights(pc::ParticleContainer) = softmax(pc.logWs) +getweights(pc::ParticleContainer) = StatsFuns.softmax(pc.logWs) """ getweight(pc::ParticleContainer, i) @@ -155,7 +153,7 @@ getweight(pc::ParticleContainer, i) = exp(pc.logWs[i] - logZ(pc)) Return the logarithm of the normalizing constant of the unnormalized logarithmic weights. """ -logZ(pc::ParticleContainer) = logsumexp(pc.logWs) +logZ(pc::ParticleContainer) = StatsFuns.logsumexp(pc.logWs) """ effectiveSampleSize(pc::ParticleContainer) @@ -168,7 +166,7 @@ function effectiveSampleSize(pc::ParticleContainer) end """ - resample_propagate!(pc::ParticleContainer[, randcat = resample_systematic, ref = nothing; + resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing; weights = getweights(pc)]) Resample and propagate the particles in `pc`. @@ -179,7 +177,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere """ function resample_propagate!( pc::ParticleContainer, - randcat = Turing.Inference.resample_systematic, + randcat = resample, ref::Union{Particle, Nothing} = nothing; weights = getweights(pc) ) @@ -231,6 +229,22 @@ function resample_propagate!( pc end +function resample_propagate!( + pc::ParticleContainer, + resampler::ResampleWithESSThreshold, + ref::Union{Particle,Nothing} = nothing; + weights = getweights(pc) +) + # Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ`` + ess = inv(sum(abs2, weights)) + + if ess ≤ resampler.threshold * length(pc) + resample_propagate!(pc, resampler.resampler, ref; weights = weights) + end + + pc +end + """ reweight!(pc::ParticleContainer) @@ -249,19 +263,18 @@ function reweight!(pc::ParticleContainer) # the execution of the model is finished. # Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and # ``θᵢ`` are variables of other samplers. - score = Libtask.consume(p) + score = advance!(p) if score === nothing numdone += 1 else - # Increase the unnormalized logarithmic weights, accounting for the variables - # of other samplers. 
-            increase_logweight!(pc, i, score + getlogp(p.vi))
+            # Increase the unnormalized logarithmic weights.
+            increase_logweight!(pc, i, score)
 
             # Reset the accumulator of the log probability in the model so that we can
             # accumulate log probabilities of variables of other samplers until the next
             # observation.
-            resetlogp!(p.vi)
+            reset_logprob!(p)
         end
     end
 
@@ -333,19 +346,3 @@ function sweep!(pc::ParticleContainer, resampler)
 
     return logevidence
 end
-
-function resample_propagate!(
-    pc::ParticleContainer,
-    resampler::ResampleWithESSThreshold,
-    ref::Union{Particle,Nothing} = nothing;
-    weights = getweights(pc)
-)
-    # Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
-    ess = inv(sum(abs2, weights))
-
-    if ess ≤ resampler.threshold * length(pc)
-        resample_propagate!(pc, resampler.resampler, ref; weights = weights)
-    end
-
-    pc
-end
diff --git a/src/resampling.jl b/src/resampling.jl
index 1cb0ef40..964ef5ee 100644
--- a/src/resampling.jl
+++ b/src/resampling.jl
@@ -16,11 +16,6 @@ function ResampleWithESSThreshold(resampler = resample)
     ResampleWithESSThreshold(resampler, 0.5)
 end
 
-# Default resampling scheme
-function resample(w::AbstractVector{<:Real}, num_particles::Integer=length(w))
-    return resample_systematic(w, num_particles)
-end
-
 # More stable, faster version of rand(Categorical)
 function randcat(p::AbstractVector{<:Real})
     T = eltype(p)
@@ -36,11 +31,17 @@ function randcat(p::AbstractVector{<:Real})
     return s
 end
 
-function resample_multinomial(w::AbstractVector{<:Real}, num_particles::Integer)
+function resample_multinomial(
+    w::AbstractVector{<:Real},
+    num_particles::Integer = length(w),
+)
     return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
 end
 
-function resample_residual(w::AbstractVector{<:Real}, num_particles::Integer)
+function resample_residual(
+    w::AbstractVector{<:Real},
+    num_particles::Integer = length(w),
+)
     # Pre-allocate array for resampled particles
     indices = Vector{Int}(undef, num_particles)
 
@@ -79,7 +80,7 @@ are selected according to the multinomial distribution defined by the normalized
 i.e., `xᵢ = j` if and only if
 ``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
 """
-function resample_stratified(weights::AbstractVector{<:Real}, n::Integer)
+function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
     # check input
     m = length(weights)
     m > 0 || error("weight vector is empty")
@@ -124,7 +125,7 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
 normalized `weights`, i.e., `xᵢ = j` if and only if
 ``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
""" -function resample_systematic(weights::AbstractVector{<:Real}, n::Integer) +function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights)) # check input m = length(weights) m > 0 || error("weight vector is empty") @@ -157,3 +158,6 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer) return samples end + +# Default resampling scheme +const resample = resample_systematic \ No newline at end of file From 241e9b5a3cca0b8343da3438f2557bb7df048450 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 3 Dec 2020 00:19:46 +0100 Subject: [PATCH 2/2] Update tests --- test/Project.toml | 4 +- test/container.jl | 107 ++++++++++++++++++---------------------------- test/runtests.jl | 2 + 3 files changed, 46 insertions(+), 67 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 73d73977..87216cb8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,7 @@ [deps] +Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -julia = "1.3" \ No newline at end of file +Libtask = "0.5" +julia = "1.3" diff --git a/test/container.jl b/test/container.jl index 6731e1e0..6a4b67a2 100644 --- a/test/container.jl +++ b/test/container.jl @@ -1,25 +1,14 @@ -using Turing, Random -using Turing: ParticleContainer, getweights, getweight, - effectiveSampleSize, Trace, current_trace, VarName, VarInfo, - Sampler, consume, produce, fork, getlogp -using Turing.Core: logZ, reset_logweights!, increase_logweight!, - resample_propagate!, reweight! -using Test - -dir = splitdir(splitdir(pathof(Turing))[1])[1] -include(dir*"/test/test_utils/AllUtils.jl") - @testset "container.jl" begin - @turing_testset "copy particle container" begin - pc = ParticleContainer(Trace[]) + @testset "copy particle container" begin + pc = AdvancedPS.ParticleContainer(AdvancedPS.Trace[]) newpc = copy(pc) @test newpc.logWs == pc.logWs @test typeof(pc) === typeof(newpc) end - @turing_testset "particle container" begin - # Create a resumable function that always yields `logp`. + @testset "particle container" begin + # Create a resumable function that always returns the same log probability. function fpc(logp) f = let logp = logp () -> begin @@ -31,97 +20,83 @@ include(dir*"/test/test_utils/AllUtils.jl") return f end - # Dummy sampler that is not actually used. - sampler = Sampler(PG(5), empty_model()) - # Create particle container. logps = [0.0, -1.0, -2.0] - particles = [Trace(fpc(logp), empty_model(), sampler, VarInfo()) for logp in logps] - pc = ParticleContainer(particles) + particles = [AdvancedPS.Trace(fpc(logp)) for logp in logps] + pc = AdvancedPS.ParticleContainer(particles) # Initial state. - @test all(iszero(getlogp(particle.vi)) for particle in pc.vals) @test pc.logWs == zeros(3) - @test getweights(pc) == fill(1/3, 3) - @test all(getweight(pc, i) == 1/3 for i in 1:3) - @test logZ(pc) ≈ log(3) - @test effectiveSampleSize(pc) == 3 + @test AdvancedPS.getweights(pc) == fill(1/3, 3) + @test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3) + @test AdvancedPS.logZ(pc) ≈ log(3) + @test AdvancedPS.effectiveSampleSize(pc) == 3 # Reweight particles. 
- reweight!(pc) - @test all(iszero(getlogp(particle.vi)) for particle in pc.vals) + AdvancedPS.reweight!(pc) @test pc.logWs == logps - @test getweights(pc) ≈ exp.(logps) ./ sum(exp, logps) - @test all(getweight(pc, i) ≈ exp(logps[i]) / sum(exp, logps) for i in 1:3) - @test logZ(pc) ≈ log(sum(exp, logps)) + @test AdvancedPS.getweights(pc) ≈ exp.(logps) ./ sum(exp, logps) + @test all(AdvancedPS.getweight(pc, i) ≈ exp(logps[i]) / sum(exp, logps) for i in 1:3) + @test AdvancedPS.logZ(pc) ≈ log(sum(exp, logps)) # Reweight particles. - reweight!(pc) - @test all(iszero(getlogp(particle.vi)) for particle in pc.vals) + AdvancedPS.reweight!(pc) @test pc.logWs == 2 .* logps - @test getweights(pc) == exp.(2 .* logps) ./ sum(exp, 2 .* logps) - @test all(getweight(pc, i) ≈ exp(2 * logps[i]) / sum(exp, 2 .* logps) for i in 1:3) - @test logZ(pc) ≈ log(sum(exp, 2 .* logps)) + @test AdvancedPS.getweights(pc) == exp.(2 .* logps) ./ sum(exp, 2 .* logps) + @test all(AdvancedPS.getweight(pc, i) ≈ exp(2 * logps[i]) / sum(exp, 2 .* logps) for i in 1:3) + @test AdvancedPS.logZ(pc) ≈ log(sum(exp, 2 .* logps)) # Resample and propagate particles. - resample_propagate!(pc) - @test all(iszero(getlogp(particle.vi)) for particle in pc.vals) + AdvancedPS.resample_propagate!(pc) @test pc.logWs == zeros(3) - @test getweights(pc) == fill(1/3, 3) - @test all(getweight(pc, i) == 1/3 for i in 1:3) - @test logZ(pc) ≈ log(3) - @test effectiveSampleSize(pc) == 3 + @test AdvancedPS.getweights(pc) == fill(1/3, 3) + @test all(AdvancedPS.getweight(pc, i) == 1/3 for i in 1:3) + @test AdvancedPS.logZ(pc) ≈ log(3) + @test AdvancedPS.effectiveSampleSize(pc) == 3 # Reweight particles. - reweight!(pc) - @test all(iszero(getlogp(particle.vi)) for particle in pc.vals) + AdvancedPS.reweight!(pc) @test pc.logWs ⊆ logps - @test getweights(pc) == exp.(pc.logWs) ./ sum(exp, pc.logWs) - @test all(getweight(pc, i) ≈ exp(pc.logWs[i]) / sum(exp, pc.logWs) for i in 1:3) - @test logZ(pc) ≈ log(sum(exp, pc.logWs)) + @test AdvancedPS.getweights(pc) == exp.(pc.logWs) ./ sum(exp, pc.logWs) + @test all(AdvancedPS.getweight(pc, i) ≈ exp(pc.logWs[i]) / sum(exp, pc.logWs) for i in 1:3) + @test AdvancedPS.logZ(pc) ≈ log(sum(exp, pc.logWs)) # Increase unnormalized logarithmic weights. logws = copy(pc.logWs) - increase_logweight!(pc, 2, 1.41) + AdvancedPS.increase_logweight!(pc, 2, 1.41) @test pc.logWs == logws + [0, 1.41, 0] # Reset unnormalized logarithmic weights. 
logws = pc.logWs - reset_logweights!(pc) + AdvancedPS.reset_logweights!(pc) @test pc.logWs === logws @test all(iszero, pc.logWs) end - @turing_testset "trace" begin + @testset "trace" begin n = Ref(0) - - alg = PG(5) - spl = Sampler(alg, empty_model()) - dist = Normal(0, 1) function f2() - t = TArray(Int, 1); - t[1] = 0; + t = TArray(Int, 1) + t[1] = 0 while true - ct = current_trace() - vn = @varname x[n] - Turing.assume(Random.GLOBAL_RNG, spl, dist, vn, ct.vi) n[] += 1 produce(t[1]) - vn = @varname x[n] - Turing.assume(Random.GLOBAL_RNG, spl, dist, vn, ct.vi) n[] += 1 t[1] = 1 + t[1] end end # Test task copy version of trace - tr = Trace(f2, empty_model(), spl, VarInfo()) + tr = AdvancedPS.Trace(f2) + + consume(tr.ctask) + consume(tr.ctask) - consume(tr); consume(tr) - a = fork(tr); - consume(a); consume(a) + a = AdvancedPS.fork(tr) + consume(a.ctask) + consume(a.ctask) - @test consume(tr) == 2 - @test consume(a) == 4 + @test consume(tr.ctask) == 2 + @test consume(a.ctask) == 4 end end diff --git a/test/runtests.jl b/test/runtests.jl index 067506a3..f01dd2c1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,8 @@ using AdvancedPS +using Libtask using Test @testset "AdvancedPS.jl" begin @testset "Resampling tests" begin include("resampling.jl") end + @testset "Container tests" begin include("container.jl") end end
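
Not part of the patch, but for reviewers who want to try the new API end to end: a minimal sketch of a particle sweep built only on the pieces added above (`Trace`, `ParticleContainer`, `sweep!`, `resample_systematic`). The random-walk model and its three observations are made up for illustration.

    using AdvancedPS, Libtask

    # Hypothetical random-walk model: each call to Libtask.produce hands the
    # log-likelihood of one made-up observation back to the particle container.
    function toymodel()
        x = 0.0
        for y in (0.1, -0.4, 0.7)              # made-up observations
            x += randn()                       # latent transition
            Libtask.produce(-0.5 * (y - x)^2)  # Gaussian log-likelihood up to a constant
        end
    end

    particles = [AdvancedPS.Trace(toymodel) for _ in 1:100]
    pc = AdvancedPS.ParticleContainer(particles)

    # One SMC sweep with systematic resampling after every observation;
    # the return value is the log evidence estimate.
    logevidence = AdvancedPS.sweep!(pc, AdvancedPS.resample_systematic)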
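
The `ResampleWithESSThreshold` wrapper only triggers a resampling step when the effective sample size ``1 / ∑ wᵢ²`` drops below `threshold * length(pc)`, as implemented by the second `resample_propagate!` method above. A sketch of plugging it into a sweep; the one-observation model is hypothetical.

    using AdvancedPS, Libtask

    # Trivial single-observation model, just to have something to resample.
    onestep() = Libtask.produce(-0.5)

    # Multinomial resampling, but only when ESS falls below half the particle count.
    adaptive = AdvancedPS.ResampleWithESSThreshold(AdvancedPS.resample_multinomial, 0.5)

    pc = AdvancedPS.ParticleContainer([AdvancedPS.Trace(onestep) for _ in 1:50])
    logevidence = AdvancedPS.sweep!(pc, adaptive)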
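
With the new default arguments, each resampling scheme can also be called on its own: the number of offspring defaults to the number of weights. A standalone sketch with arbitrary normalized weights.

    using AdvancedPS

    w = [0.1, 0.2, 0.3, 0.4]                       # arbitrary normalized weights
    idx = AdvancedPS.resample_systematic(w)        # four ancestor indices by default
    idx10 = AdvancedPS.resample_stratified(w, 10)  # or any requested number
    @assert length(idx) == 4 && length(idx10) == 10
    @assert all(i -> 1 <= i <= 4, idx)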
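
The backward reference installed by `addreference!` is what lets model code look up the particle that is running it via `current_trace()`, mirroring the old `:turing_trace` mechanism. A minimal sketch; the model function is hypothetical and only illustrates that the lookup works from inside the task.

    using AdvancedPS, Libtask

    # Hypothetical model that inspects the Trace wrapping its own task.
    function introspective_model()
        trace = AdvancedPS.current_trace()
        @assert trace isa AdvancedPS.Trace
        Libtask.produce(-0.5)   # pretend log-likelihood of a single observation
    end

    t = AdvancedPS.Trace(introspective_model)
    AdvancedPS.advance!(t)      # runs to the first produce and returns -0.5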