Update the update_ref! API #85

Merged: 10 commits merged on Oct 23, 2023
Changes from 7 commits
51 changes: 16 additions & 35 deletions examples/particle-gibbs/script.jl
@@ -5,7 +5,6 @@ using Distributions
using Plots
using AbstractMCMC
using Random123
using Libtask

"""
plot_update_rate(update_rate, N)
@@ -91,39 +90,35 @@ plot(x; label="x", xlabel="t")
plot(y; label="y", xlabel="t")

# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel
X::Array
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
Member

Let's consider adding SSMProblems as a dependency and following its interface. We can also remove the out-of-date SSM interface from AdvancedPS.

Member Author

Agreed, but maybe in another PR? I think some work is needed on how we handle the observations/data here compared to SSMProblems.

X::Vector{Float64}
θ::Parameters
NonLinearTimeSeries(θ::Parameters) = new(zeros(Float64, θ.T), θ)
NonLinearTimeSeries(θ::Parameters) = new(Float64[], θ)
end

function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
x₀ = rand(rng, f₀(model.θ))
model.X[1] = x₀
score = logpdf(g(model.θ, x₀, 1), y[1])
Libtask.produce(score)

for t in 2:(model.θ.T)
state = rand(rng, f(model.θ, model.X[t - 1], t - 1))
model.X[t] = state
score = logpdf(g(model.θ, state, t), y[t])
Libtask.produce(score)
end
# The dynamics of the model are defined through the `AbstractStateSpaceModel` interface:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model.θ, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide

particles = hcat([chain.trajectory.X for chain in chains]...) # Concat all sampled states
particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2);
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
scatter(
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

@@ -133,29 +128,15 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ)

# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
mutable struct NonLinearSSM <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearSSM(θ::Parameters) = new(Float64[], θ)
end

AdvancedPS.initialization(model::NonLinearSSM) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearSSM, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearSSM, state, step)
return logpdf(g(model.θ, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearSSM, step) = step > Tₘ

# We can now sample from the model using the PGAS sampler and collect the trajectories.
pgas = AdvancedPS.PGAS(Nₚ)
model = NonLinearSSM(θ₀)
chains = sample(rng, model, pgas, Nₛ; progress=false);
particles = hcat([chain.trajectory.model.X for chain in chains]...);
mean_trajectory = mean(particles; dims=2);

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
scatter(
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
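
The example script now describes the model through four functions: `initialization`, `transition`, `observation` and `isdone`. The sketch below shows, under simplifying assumptions, how a driver loop could consume such an interface. `ToyModel`, `simulate` and the four local functions are hypothetical stand-ins defined here, not the AdvancedPS methods; in particular they draw samples directly instead of returning distributions.

```julia
using Random

# Hypothetical toy model: T time steps and a vector of observations.
struct ToyModel
    T::Int
    y::Vector{Float64}
end

# Local stand-ins with the same roles as the interface functions in the diff.
initialization(::ToyModel, rng) = randn(rng)                          # draw the first state
transition(::ToyModel, rng, state, step) = 0.9 * state + randn(rng)   # draw the next state
observation(m::ToyModel, state, step) = -0.5 * (m.y[step] - state)^2  # unnormalised Gaussian log-density
isdone(m::ToyModel, step) = step > m.T

# Run one trajectory forward, accumulating the observation log-density.
function simulate(rng, m::ToyModel)
    X = Float64[]   # trajectory starts empty and grows one state per step
    logw = 0.0
    step = 1
    state = initialization(m, rng)
    while !isdone(m, step)
        push!(X, state)
        logw += observation(m, state, step)
        state = transition(m, rng, state, step)
        step += 1
    end
    return X, logw
end

X, logw = simulate(MersenneTwister(1), ToyModel(5, zeros(5)))
```

The trajectory is built with `push!` on an initially empty vector, which is one plausible reading of why the constructor in the script changed from `zeros(Float64, θ.T)` to `Float64[]`.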

4 changes: 2 additions & 2 deletions ext/AdvancedPSLibtaskExt.jl
@@ -146,7 +146,7 @@ function AbstractMCMC.step(

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, reference)
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(rng, particles)
@@ -184,7 +184,7 @@ function AbstractMCMC.sample(
particles = AdvancedPS.ParticleContainer(traces, AdvancedPS.TracedRNG(), rng)

# Perform particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler)
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler)

replayed = map(particle -> AdvancedPS.replay(particle).model.f, particles.vals)

4 changes: 3 additions & 1 deletion src/AdvancedPS.jl
@@ -8,6 +8,8 @@ using Random123: Random123

abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end

abstract type ParticleSampler <: AbstractMCMC.AbstractSampler end

""" Abstract type for an abstract model formulated in the state space form
"""
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
@@ -17,8 +19,8 @@ include("resampling.jl")
include("rng.jl")
include("model.jl")
include("container.jl")
include("pgas.jl")
include("smc.jl")
include("pgas.jl")

if !isdefined(Base, :get_extension)
using Requires
15 changes: 10 additions & 5 deletions src/container.jl
@@ -61,7 +61,9 @@ end

Update reference trajectory. Defaults to `nothing`
"""
update_ref!(particle::Trace, pc::ParticleContainer) = nothing
function update_ref!(particle::Trace, pc::ParticleContainer, sampler::ParticleSampler)
return nothing
end

"""
reset_logweights!(pc::ParticleContainer)
@@ -167,6 +169,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
function resample_propagate!(
::Random.AbstractRNG,
pc::ParticleContainer,
sampler::ParticleSampler,
randcat=DEFAULT_RESAMPLER,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
@@ -214,7 +217,7 @@ function resample_propagate!(
if ref !== nothing
# Insert the retained particle. This is based on the replaying trick for efficiency
# reasons. If we implement PG using task copying, we need to store Nx * T particles!
update_ref!(ref, pc)
update_ref!(ref, pc, sampler)
@inbounds children[n] = ref
end

@@ -228,6 +231,7 @@ end
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractMCMC.AbstractSampler,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
@@ -236,7 +240,7 @@ function resample_propagate!(
ess = inv(sum(abs2, weights))

if ess ≤ resampler.threshold * length(pc)
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
resample_propagate!(rng, pc, sampler, resampler.resampler, ref; weights=weights)
else
update_keys!(pc, ref)
end
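
The branch in this hunk compares the effective sample size, computed as the inverse of the sum of squared normalised weights, against `threshold * N`. A minimal sketch of that rule on toy weights follows; `ess` and `should_resample` are hypothetical helper names introduced only for illustration and do not exist in AdvancedPS.

```julia
# Effective sample size of normalised weights: 1 / sum(w_i^2).
ess(weights) = inv(sum(abs2, weights))

# Hypothetical helper: resample only when the ESS is at or below threshold * N.
should_resample(weights, threshold) = ess(weights) ≤ threshold * length(weights)

uniform = fill(1 / 4, 4)            # perfectly balanced weights
degenerate = [1.0, 0.0, 0.0, 0.0]   # all mass on one particle

ess(uniform)                        # 4.0
ess(degenerate)                     # 1.0
should_resample(uniform, 0.5)       # false
should_resample(degenerate, 0.5)    # true
should_resample(uniform, 1.0)       # true
```

Since the ESS cannot exceed the number of particles in exact arithmetic, a threshold of `1.0` means resampling at every step, which matches the "without adaptive resampling" setting `PG(Nₚ, 1.0)` used in the example script.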
@@ -311,11 +315,12 @@ function sweep!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
resampler,
sampler::AbstractMCMC.AbstractSampler,
ref::Union{Particle,Nothing}=nothing,
)
# Initial step:
# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)
resample_propagate!(rng, pc, sampler, resampler, ref)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
@@ -336,7 +341,7 @@ function sweep!(
# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)
resample_propagate!(rng, pc, sampler, resampler, ref)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
2 changes: 1 addition & 1 deletion src/pgas.jl
@@ -133,7 +133,7 @@ function forkr(particle::SSMTrace)
return newtrace
end

function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace})
function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS)
current_step(ref) <= 2 && return nothing # At the beginning of step + 1 since we start at 1
isdone(ref.model, current_step(ref)) && return nothing
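
Adding the sampler as an argument to `update_ref!` appears to let multiple dispatch decide whether the ancestor update runs: a generic fallback does nothing, while a method specialised on the PGAS sampler performs the update. The sketch below illustrates that dispatch pattern with toy types. Every name in it (`ToySampler`, `ToyPG`, `ToyPGAS`, `ToyTrace`, `ToySSMTrace`, `update_ref`) is hypothetical and only mirrors the shape of the signatures in this PR.

```julia
# Toy stand-ins; none of these are the AdvancedPS types.
abstract type ToySampler end
struct ToyPG <: ToySampler end
struct ToyPGAS <: ToySampler end

struct ToyTrace end
struct ToySSMTrace end

# Generic fallback: any reference, any sampler, no ancestor update.
update_ref(ref, sampler::ToySampler) = :noop

# Specialised method: only the SSM-style trace combined with the PGAS-like sampler.
update_ref(ref::ToySSMTrace, sampler::ToyPGAS) = :ancestor_sampling

update_ref(ToySSMTrace(), ToyPG())     # :noop
update_ref(ToySSMTrace(), ToyPGAS())   # :ancestor_sampling
update_ref(ToyTrace(), ToyPGAS())      # :noop
```

With this pattern, two samplers can share the same stepping code and differ only in which `update_ref` method is selected, which would be consistent with the `Union{PGAS,PG}` signature on `AbstractMCMC.step` in `src/smc.jl`.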

12 changes: 6 additions & 6 deletions src/smc.jl
@@ -1,4 +1,4 @@
struct SMC{R} <: AbstractMCMC.AbstractSampler
struct SMC{R} <: ParticleSampler
nparticles::Int
resampler::R
end
@@ -46,12 +46,12 @@ function AbstractMCMC.sample(
particles = ParticleContainer(traces, TracedRNG(), rng)

# Perform particle sweep.
logevidence = sweep!(rng, particles, sampler.resampler)
logevidence = sweep!(rng, particles, sampler.resampler, sampler)

return SMCSample(collect(particles), getweights(particles), logevidence)
end

struct PG{R} <: AbstractMCMC.AbstractSampler
struct PG{R} <: ParticleSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
@@ -84,7 +84,7 @@ struct PGSample{T,L}
logevidence::L
end

struct PGAS{R} <: AbstractMCMC.AbstractSampler
struct PGAS{R} <: ParticleSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
@@ -96,7 +96,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0))
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractStateSpaceModel,
sampler::PGAS,
sampler::Union{PGAS,PG},
state::Union{PGState,Nothing}=nothing;
kwargs...,
)
@@ -116,7 +116,7 @@ function AbstractMCMC.step(

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = sweep!(rng, particles, sampler.resampler, reference)
logevidence = sweep!(rng, particles, sampler.resampler, sampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(particles.rng, particles)
5 changes: 3 additions & 2 deletions test/container.jl
@@ -73,8 +73,9 @@
ref = AdvancedPS.forkr(selected)
pc_ref.vals[end] = ref

sampler = AdvancedPS.PG(length(logps))
AdvancedPS.resample_propagate!(
Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, ref
Random.GLOBAL_RNG, pc_ref, sampler, AdvancedPS.resample_systematic, ref
)
@test pc_ref.logWs == zeros(3)
@test AdvancedPS.getweights(pc_ref) == fill(1 / 3, 3)
@@ -84,7 +85,7 @@
@test pc_ref.vals[end] === particles_ref[end]

# Resample and propagate particles.
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc, sampler)
@test pc.logWs == zeros(3)
@test AdvancedPS.getweights(pc) == fill(1 / 3, 3)
@test all(AdvancedPS.getweight(pc, i) == 1 / 3 for i in 1:3)
5 changes: 3 additions & 2 deletions test/pgas.jl
@@ -46,6 +46,7 @@
AdvancedPS.Trace(BaseModel(Params(0.9, 0.31, 1)), AdvancedPS.TracedRNG()) for
_ in 1:3
]
sampler = AdvancedPS.PGAS(3)
resampler = AdvancedPS.ResampleWithESSThreshold(1.0)

part = particles[3]
@@ -58,11 +59,11 @@
pc = AdvancedPS.ParticleContainer(particles, AdvancedPS.TracedRNG(), base_rng)

AdvancedPS.reweight!(pc, ref)
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)

AdvancedPS.reweight!(pc, ref)
pc.logWs = [-Inf, 0, -Inf] # Force ancestor update to second particle
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)

AdvancedPS.reweight!(pc, ref)
@test all(pc.vals[2].model.X[1:2] .≈ ref.model.X[1:2])