Skip to content

Commit

Permalink
Update the update_ref! API (#85)
Browse files Browse the repository at this point in the history
* Try to fix the API, `update_ref!` is broken

* Introduce a `ParticleSampler` type (#86)

* Update smc.jl

* Update smc.jl

* Apply suggestions from code review

* Update smc.jl

* Update AdvancedPS.jl

* Update src/container.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/AdvancedPS.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Apply suggestions from code review

* fix typo

* Changing the name

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Oct 23, 2023
1 parent e545c86 commit 67d3496
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 54 deletions.
51 changes: 16 additions & 35 deletions examples/particle-gibbs/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using Distributions
using Plots
using AbstractMCMC
using Random123
using Libtask

"""
plot_update_rate(update_rate, N)
Expand Down Expand Up @@ -91,39 +90,35 @@ plot(x; label="x", xlabel="t")
plot(y; label="y", xlabel="t")

# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel
X::Array
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearTimeSeries::Parameters) = new(zeros(Float64, θ.T), θ)
NonLinearTimeSeries::Parameters) = new(Float64[], θ)
end

function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
x₀ = rand(rng, f₀(model.θ))
model.X[1] = x₀
score = logpdf(g(model.θ, x₀, 1), y[1])
Libtask.produce(score)

for t in 2:(model.θ.T)
state = rand(rng, f(model.θ, model.X[t - 1], t - 1))
model.X[t] = state
score = logpdf(g(model.θ, state, t), y[t])
Libtask.produce(score)
end
# The dynamics of the model is defined through the `AbstractStateSpaceModel` interface:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model.θ, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide

particles = hcat([chain.trajectory.X for chain in chains]...) # Concat all sampled states
particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2);
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=1.01, color=:black, xlabel="t", ylabel="state")
scatter(
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

Expand All @@ -133,29 +128,15 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ)

# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
mutable struct NonLinearSSM <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearSSM::Parameters) = new(Float64[], θ)
end

AdvancedPS.initialization(model::NonLinearSSM) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearSSM, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearSSM, state, step)
return logpdf(g(model.θ, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearSSM, step) = step > Tₘ

# We can now sample from the model using the PGAS sampler and collect the trajectories.
pgas = AdvancedPS.PGAS(Nₚ)
model = NonLinearSSM(θ₀)
chains = sample(rng, model, pgas, Nₛ; progress=false);
particles = hcat([chain.trajectory.model.X for chain in chains]...);
mean_trajectory = mean(particles; dims=2);

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
scatter(
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

Expand Down
4 changes: 2 additions & 2 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ function AbstractMCMC.step(

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, reference)
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(rng, particles)
Expand Down Expand Up @@ -184,7 +184,7 @@ function AbstractMCMC.sample(
particles = AdvancedPS.ParticleContainer(traces, AdvancedPS.TracedRNG(), rng)

# Perform particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler)
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler)

replayed = map(particle -> AdvancedPS.replay(particle).model.f, particles.vals)

Expand Down
4 changes: 3 additions & 1 deletion src/AdvancedPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ using Random123: Random123

abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end

abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end

""" Abstract type for an abstract model formulated in the state space form
"""
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
Expand All @@ -17,8 +19,8 @@ include("resampling.jl")
include("rng.jl")
include("model.jl")
include("container.jl")
include("pgas.jl")
include("smc.jl")
include("pgas.jl")

if !isdefined(Base, :get_extension)
using Requires
Expand Down
17 changes: 12 additions & 5 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ end
Update reference trajectory. Defaults to `nothing`
"""
update_ref!(particle::Trace, pc::ParticleContainer) = nothing
function update_ref!(
particle::Trace, pc::ParticleContainer, sampler::AbstractParticleSampler
)
return nothing
end

"""
reset_logweights!(pc::ParticleContainer)
Expand Down Expand Up @@ -167,6 +171,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
function resample_propagate!(
::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractParticleSampler,
randcat=DEFAULT_RESAMPLER,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
Expand Down Expand Up @@ -214,7 +219,7 @@ function resample_propagate!(
if ref !== nothing
# Insert the retained particle. This is based on the replaying trick for efficiency
# reasons. If we implement PG using task copying, we need to store Nx * T particles!
update_ref!(ref, pc)
update_ref!(ref, pc, sampler)
@inbounds children[n] = ref
end

Expand All @@ -228,6 +233,7 @@ end
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractParticleSampler,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
Expand All @@ -236,7 +242,7 @@ function resample_propagate!(
ess = inv(sum(abs2, weights))

if ess resampler.threshold * length(pc)
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
resample_propagate!(rng, pc, sampler, resampler.resampler, ref; weights=weights)
else
update_keys!(pc, ref)
end
Expand Down Expand Up @@ -311,11 +317,12 @@ function sweep!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
resampler,
sampler::AbstractMCMC.AbstractSampler,
ref::Union{Particle,Nothing}=nothing,
)
# Initial step:
# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)
resample_propagate!(rng, pc, sampler, resampler, ref)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand All @@ -336,7 +343,7 @@ function sweep!(
# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)
resample_propagate!(rng, pc, sampler, resampler, ref)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand Down
2 changes: 1 addition & 1 deletion src/pgas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ function forkr(particle::SSMTrace)
return newtrace
end

function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace})
function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS)
current_step(ref) <= 2 && return nothing # At the beginning of step + 1 since we start at 1
isdone(ref.model, current_step(ref)) && return nothing

Expand Down
12 changes: 6 additions & 6 deletions src/smc.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
struct SMC{R} <: AbstractMCMC.AbstractSampler
struct SMC{R} <: AbstractParticleSampler
nparticles::Int
resampler::R
end
Expand Down Expand Up @@ -46,12 +46,12 @@ function AbstractMCMC.sample(
particles = ParticleContainer(traces, TracedRNG(), rng)

# Perform particle sweep.
logevidence = sweep!(rng, particles, sampler.resampler)
logevidence = sweep!(rng, particles, sampler.resampler, sampler)

return SMCSample(collect(particles), getweights(particles), logevidence)
end

struct PG{R} <: AbstractMCMC.AbstractSampler
struct PG{R} <: AbstractParticleSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
Expand Down Expand Up @@ -84,7 +84,7 @@ struct PGSample{T,L}
logevidence::L
end

struct PGAS{R} <: AbstractMCMC.AbstractSampler
struct PGAS{R} <: AbstractParticleSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
Expand All @@ -96,7 +96,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0))
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractStateSpaceModel,
sampler::PGAS,
sampler::Union{PGAS,PG},
state::Union{PGState,Nothing}=nothing;
kwargs...,
)
Expand All @@ -116,7 +116,7 @@ function AbstractMCMC.step(

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = sweep!(rng, particles, sampler.resampler, reference)
logevidence = sweep!(rng, particles, sampler.resampler, sampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(particles.rng, particles)
Expand Down
5 changes: 3 additions & 2 deletions test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@
ref = AdvancedPS.forkr(selected)
pc_ref.vals[end] = ref

sampler = AdvancedPS.PG(length(logps))
AdvancedPS.resample_propagate!(
Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, ref
Random.GLOBAL_RNG, pc_ref, sampler, AdvancedPS.resample_systematic, ref
)
@test pc_ref.logWs == zeros(3)
@test AdvancedPS.getweights(pc_ref) == fill(1 / 3, 3)
Expand All @@ -84,7 +85,7 @@
@test pc_ref.vals[end] === particles_ref[end]

# Resample and propagate particles.
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc, sampler)
@test pc.logWs == zeros(3)
@test AdvancedPS.getweights(pc) == fill(1 / 3, 3)
@test all(AdvancedPS.getweight(pc, i) == 1 / 3 for i in 1:3)
Expand Down
5 changes: 3 additions & 2 deletions test/pgas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
AdvancedPS.Trace(BaseModel(Params(0.9, 0.31, 1)), AdvancedPS.TracedRNG()) for
_ in 1:3
]
sampler = AdvancedPS.PGAS(3)
resampler = AdvancedPS.ResampleWithESSThreshold(1.0)

part = particles[3]
Expand All @@ -58,11 +59,11 @@
pc = AdvancedPS.ParticleContainer(particles, AdvancedPS.TracedRNG(), base_rng)

AdvancedPS.reweight!(pc, ref)
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)

AdvancedPS.reweight!(pc, ref)
pc.logWs = [-Inf, 0, -Inf] # Force ancestor update to second particle
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)

AdvancedPS.reweight!(pc, ref)
@test all(pc.vals[2].model.X[1:2] .≈ ref.model.X[1:2])
Expand Down

2 comments on commit 67d3496

@yebai
Copy link
Member

@yebai yebai commented on 67d3496 Oct 23, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/93978

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.0 -m "<description of version>" 67d34964418c7a69320fecba1739a7510b935a34
git push origin v0.5.0

Please sign in to comment.