diff --git a/src/container.jl b/src/container.jl index bb0fe76f..504bc329 100644 --- a/src/container.jl +++ b/src/container.jl @@ -314,11 +314,16 @@ The resampling steps use the given `resampler`. Del Moral, P., Doucet, A., & Jasra, A. (2006). Sequential monte carlo samplers. Journal of the Royal Statistical Society: Series B (Statistical Methodology), 68(3), 411-436. """ -function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler) +function sweep!( + rng::Random.AbstractRNG, + pc::ParticleContainer, + resampler, + ref::Union{Particle,Nothing}=nothing, +) # Initial step: # Resample and propagate particles. - resample_propagate!(rng, pc, resampler) + resample_propagate!(rng, pc, resampler, ref) # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic # weights. @@ -339,7 +344,7 @@ function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler) # For observations ``y₂, …, yₜ``: while !isdone # Resample and propagate particles. - resample_propagate!(rng, pc, resampler) + resample_propagate!(rng, pc, resampler, ref) # Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic # weights. diff --git a/src/smc.jl b/src/smc.jl index 1beab676..89d465c8 100644 --- a/src/smc.jl +++ b/src/smc.jl @@ -114,7 +114,7 @@ function AbstractMCMC.step( particles = ParticleContainer(x) # Perform a particle sweep. - logevidence = sweep!(rng, particles, sampler.resampler) + logevidence = sweep!(rng, particles, sampler.resampler, particles.vals[nparticles]) # Pick a particle to be retained. newtrajectory = rand(rng, particles) diff --git a/test/container.jl b/test/container.jl index 7313ac79..69c524e5 100644 --- a/test/container.jl +++ b/test/container.jl @@ -51,6 +51,19 @@ ) @test AdvancedPS.logZ(pc) ≈ log(sum(exp, 2 .* logps)) + # Resample and propagate particles with reference particle + particles_ref = [AdvancedPS.Trace(fpc(logp)) for logp in logps] + pc_ref = AdvancedPS.ParticleContainer(particles_ref) + AdvancedPS.resample_propagate!( + Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, particles_ref[end] + ) + @test pc_ref.logWs == zeros(3) + @test AdvancedPS.getweights(pc_ref) == fill(1 / 3, 3) + @test all(AdvancedPS.getweight(pc_ref, i) == 1 / 3 for i in 1:3) + @test AdvancedPS.logZ(pc_ref) ≈ log(3) + @test AdvancedPS.effectiveSampleSize(pc_ref) == 3 + @test pc_ref.vals[end] === particles_ref[end] + # Resample and propagate particles. AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc) @test pc.logWs == zeros(3)