Skip to content

Commit

Permalink
Merge branch 'master' into fred/turing
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Nov 3, 2023
2 parents 118e58e + b157064 commit a9f0220
Show file tree
Hide file tree
Showing 10 changed files with 49 additions and 58 deletions.
5 changes: 2 additions & 3 deletions Project.toml
@@ -1,12 +1,11 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.5"
version = "0.5.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Expand All @@ -19,7 +18,7 @@ Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
AdvancedPSLibtaskExt = "Libtask"

[compat]
AbstractMCMC = "2, 3, 4"
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Random123 = "1.3"
Expand Down
51 changes: 16 additions & 35 deletions examples/particle-gibbs/script.jl
Expand Up @@ -5,7 +5,6 @@ using Distributions
using Plots
using AbstractMCMC
using Random123
using Libtask

"""
plot_update_rate(update_rate, N)
Expand Down Expand Up @@ -91,39 +90,35 @@ plot(x; label="x", xlabel="t")
plot(y; label="y", xlabel="t")

# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractGenericModel
X::Array
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearTimeSeries::Parameters) = new(zeros(Float64, θ.T), θ)
NonLinearTimeSeries::Parameters) = new(Float64[], θ)
end

function (model::NonLinearTimeSeries)(rng::Random.AbstractRNG)
x₀ = rand(rng, f₀(model.θ))
model.X[1] = x₀
score = logpdf(g(model.θ, x₀, 1), y[1])
Libtask.produce(score)

for t in 2:(model.θ.T)
state = rand(rng, f(model.θ, model.X[t - 1], t - 1))
model.X[t] = state
score = logpdf(g(model.θ, state, t), y[t])
Libtask.produce(score)
end
# The dynamics of the model is defined through the `AbstractStateSpaceModel` interface:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model.θ, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide

particles = hcat([chain.trajectory.X for chain in chains]...) # Concat all sampled states
particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = mean(particles; dims=2);
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=1.01, color=:black, xlabel="t", ylabel="state")
scatter(
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

Expand All @@ -133,29 +128,15 @@ plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ)

# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
mutable struct NonLinearSSM <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Float64}
θ::Parameters
NonLinearSSM::Parameters) = new(Float64[], θ)
end

AdvancedPS.initialization(model::NonLinearSSM) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearSSM, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearSSM, state, step)
return logpdf(g(model.θ, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearSSM, step) = step > Tₘ

# We can now sample from the model using the PGAS sampler and collect the trajectories.
pgas = AdvancedPS.PGAS(Nₚ)
model = NonLinearSSM(θ₀)
chains = sample(rng, model, pgas, Nₛ; progress=false);
particles = hcat([chain.trajectory.model.X for chain in chains]...);
mean_trajectory = mean(particles; dims=2);

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
scatter(
particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

Expand Down
4 changes: 2 additions & 2 deletions ext/AdvancedPSLibtaskExt.jl
Expand Up @@ -143,7 +143,7 @@ function AbstractMCMC.step(

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, reference)
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(rng, particles)
Expand Down Expand Up @@ -180,7 +180,7 @@ function AbstractMCMC.sample(
particles = AdvancedPS.ParticleContainer(traces, AdvancedPS.TracedRNG(), rng)

# Perform particle sweep.
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler)
logevidence = AdvancedPS.sweep!(rng, particles, sampler.resampler, sampler)

replayed = map(particle -> AdvancedPS.replay(particle).model.f, particles.vals)

Expand Down
4 changes: 3 additions & 1 deletion src/AdvancedPS.jl
Expand Up @@ -8,6 +8,8 @@ using Random123: Random123

abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end

abstract type AbstractParticleSampler <: AbstractMCMC.AbstractSampler end

""" Abstract type for an abstract model formulated in the state space form
"""
abstract type AbstractStateSpaceModel <: AbstractParticleModel end
Expand All @@ -17,8 +19,8 @@ include("resampling.jl")
include("rng.jl")
include("model.jl")
include("container.jl")
include("pgas.jl")
include("smc.jl")
include("pgas.jl")

if !isdefined(Base, :get_extension)
using Requires
Expand Down
17 changes: 12 additions & 5 deletions src/container.jl
Expand Up @@ -61,7 +61,11 @@ end
Update reference trajectory. Defaults to `nothing`
"""
update_ref!(particle::Trace, pc::ParticleContainer) = nothing
function update_ref!(
particle::Trace, pc::ParticleContainer, sampler::AbstractParticleSampler
)
return nothing
end

"""
reset_logweights!(pc::ParticleContainer)
Expand Down Expand Up @@ -167,6 +171,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
function resample_propagate!(
::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractParticleSampler,
randcat=DEFAULT_RESAMPLER,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
Expand Down Expand Up @@ -214,7 +219,7 @@ function resample_propagate!(
if ref !== nothing
# Insert the retained particle. This is based on the replaying trick for efficiency
# reasons. If we implement PG using task copying, we need to store Nx * T particles!
update_ref!(ref, pc)
update_ref!(ref, pc, sampler)
@inbounds children[n] = ref
end

Expand All @@ -228,6 +233,7 @@ end
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractParticleSampler,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
Expand All @@ -236,7 +242,7 @@ function resample_propagate!(
ess = inv(sum(abs2, weights))

if ess resampler.threshold * length(pc)
resample_propagate!(rng, pc, resampler.resampler, ref; weights=weights)
resample_propagate!(rng, pc, sampler, resampler.resampler, ref; weights=weights)
else
update_keys!(pc, ref)
end
Expand Down Expand Up @@ -311,11 +317,12 @@ function sweep!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
resampler,
sampler::AbstractMCMC.AbstractSampler,
ref::Union{Particle,Nothing}=nothing,
)
# Initial step:
# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)
resample_propagate!(rng, pc, sampler, resampler, ref)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand All @@ -336,7 +343,7 @@ function sweep!(
# For observations ``y₂, …, yₜ``:
while !isdone
# Resample and propagate particles.
resample_propagate!(rng, pc, resampler, ref)
resample_propagate!(rng, pc, sampler, resampler, ref)

# Compute the current normalizing constant ``Z₀`` of the unnormalized logarithmic
# weights.
Expand Down
2 changes: 1 addition & 1 deletion src/pgas.jl
Expand Up @@ -133,7 +133,7 @@ function forkr(particle::SSMTrace)
return newtrace
end

function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace})
function update_ref!(ref::SSMTrace, pc::ParticleContainer{<:SSMTrace}, sampler::PGAS)
current_step(ref) <= 2 && return nothing # At the beginning of step + 1 since we start at 1
isdone(ref.model, current_step(ref)) && return nothing

Expand Down
12 changes: 6 additions & 6 deletions src/smc.jl
@@ -1,4 +1,4 @@
struct SMC{R} <: AbstractMCMC.AbstractSampler
struct SMC{R} <: AbstractParticleSampler
nparticles::Int
resampler::R
end
Expand Down Expand Up @@ -46,12 +46,12 @@ function AbstractMCMC.sample(
particles = ParticleContainer(traces, TracedRNG(), rng)

# Perform particle sweep.
logevidence = sweep!(rng, particles, sampler.resampler)
logevidence = sweep!(rng, particles, sampler.resampler, sampler)

return SMCSample(collect(particles), getweights(particles), logevidence)
end

struct PG{R} <: AbstractMCMC.AbstractSampler
struct PG{R} <: AbstractParticleSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
Expand Down Expand Up @@ -84,7 +84,7 @@ struct PGSample{T,L}
logevidence::L
end

struct PGAS{R} <: AbstractMCMC.AbstractSampler
struct PGAS{R} <: AbstractParticleSampler
"""Number of particles."""
nparticles::Int
"""Resampling algorithm."""
Expand All @@ -96,7 +96,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0))
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::AbstractStateSpaceModel,
sampler::PGAS,
sampler::Union{PGAS,PG},
state::Union{PGState,Nothing}=nothing;
kwargs...,
)
Expand All @@ -116,7 +116,7 @@ function AbstractMCMC.step(

# Perform a particle sweep.
reference = isref ? particles.vals[nparticles] : nothing
logevidence = sweep!(rng, particles, sampler.resampler, reference)
logevidence = sweep!(rng, particles, sampler.resampler, sampler, reference)

# Pick a particle to be retained.
newtrajectory = rand(particles.rng, particles)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Expand Up @@ -7,7 +7,7 @@ Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
AbstractMCMC = "2, 3, 4"
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.24, 0.25"
Libtask = "0.8"
Random123 = "1.3"
Expand Down
5 changes: 3 additions & 2 deletions test/container.jl
Expand Up @@ -73,8 +73,9 @@
ref = AdvancedPS.forkr(selected)
pc_ref.vals[end] = ref

sampler = AdvancedPS.PG(length(logps))
AdvancedPS.resample_propagate!(
Random.GLOBAL_RNG, pc_ref, AdvancedPS.resample_systematic, ref
Random.GLOBAL_RNG, pc_ref, sampler, AdvancedPS.resample_systematic, ref
)
@test pc_ref.logWs == zeros(3)
@test AdvancedPS.getweights(pc_ref) == fill(1 / 3, 3)
Expand All @@ -84,7 +85,7 @@
@test pc_ref.vals[end] === particles_ref[end]

# Resample and propagate particles.
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc)
AdvancedPS.resample_propagate!(Random.GLOBAL_RNG, pc, sampler)
@test pc.logWs == zeros(3)
@test AdvancedPS.getweights(pc) == fill(1 / 3, 3)
@test all(AdvancedPS.getweight(pc, i) == 1 / 3 for i in 1:3)
Expand Down
5 changes: 3 additions & 2 deletions test/pgas.jl
Expand Up @@ -46,6 +46,7 @@
AdvancedPS.Trace(BaseModel(Params(0.9, 0.31, 1)), AdvancedPS.TracedRNG()) for
_ in 1:3
]
sampler = AdvancedPS.PGAS(3)
resampler = AdvancedPS.ResampleWithESSThreshold(1.0)

part = particles[3]
Expand All @@ -58,11 +59,11 @@
pc = AdvancedPS.ParticleContainer(particles, AdvancedPS.TracedRNG(), base_rng)

AdvancedPS.reweight!(pc, ref)
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)

AdvancedPS.reweight!(pc, ref)
pc.logWs = [-Inf, 0, -Inf] # Force ancestor update to second particle
AdvancedPS.resample_propagate!(base_rng, pc, resampler, ref)
AdvancedPS.resample_propagate!(base_rng, pc, sampler, resampler, ref)

AdvancedPS.reweight!(pc, ref)
@test all(pc.vals[2].model.X[1:2] .≈ ref.model.X[1:2])
Expand Down

0 comments on commit a9f0220

Please sign in to comment.