Skip to content

Commit

Permalink
Update script.jl (#79)
Browse files Browse the repository at this point in the history
* Update script.jl

Fixed some typos in the model and one in the code (variance was used instead of standard deviation in measurement model).
Added PGAS sampling at the end to show that it solves the degeneracy problem, which should close the issue #77

* Update script.jl

added back the rng argument when sampling with PG.

* Update examples/particle-gibbs/script.jl

Co-authored-by: FredericWantiez <frederic.wantiez@gmail.com>

---------

Co-authored-by: FredericWantiez <frederic.wantiez@gmail.com>
  • Loading branch information
mattiasvillani and FredericWantiez committed Sep 17, 2023
1 parent d785ce5 commit 72a1e55
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions examples/particle-gibbs/script.jl
@@ -1,4 +1,3 @@
# # Particle Gibbs for non-linear models
using AdvancedPS
using Random
using Distributions
Expand All @@ -13,34 +12,33 @@ using Libtask
# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, r^2)
# ```
# ```math
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad v_{t} \sim \mathcal{N}(0, 1)
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_t \sim \mathcal{N}(0, 1)
# ```
#
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
# We can reformulate the above in terms of transition and observation densities:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, r^2)
# ```
# ```math
# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2)
# ```
# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$.
# Here we assume the static parameters $\theta = (q^2, r^2)$ are known and we are only interested in sampling from the latent state $x_t$.
Parameters = @NamedTuple begin
a::Float64
q::Float64
T::Int
end

mutable struct NonLinearTimeSeries <: AbstractMCMC.AbstractModel
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
X::Array
θ::Parameters
NonLinearTimeSeries::Parameters) = new(zeros(Float64, θ.T), θ)
end

f(model::NonLinearTimeSeries, state, t) = Normal(model.θ.a * state, model.θ.q)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state)^2)
g(model::NonLinearTimeSeries, state, t) = Normal(0, exp(0.5 * state))
f₀(model::NonLinearTimeSeries) = Normal(0, model.θ.q)
#md nothing #hide

# Let's simulate some data
a = 0.9 # State Variance
Expand Down Expand Up @@ -88,8 +86,8 @@ end

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
pgas = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pgas, Nₛ; progress=false);
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide

# The trajectories are not stored during the sampling and we need to regenerate the history of each
Expand All @@ -116,8 +114,7 @@ mean_trajectory = mean(particles; dims=2)
#md nothing #hide

# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help
# with the degeneracy problem.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
Expand All @@ -137,3 +134,35 @@ plot(
ylabel="Update rate",
)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")

# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
# To use this sampler we need to define the transition and observation densities as well as the initial distribution in the following way:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model, state, step), y[step])
end
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# We can now sample from the model using the PGAS sampler and collect the trajectories.
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(model, pg, Nₛ);
particles = hcat([trajectory.model.f.X for trajectory in trajectories]...)
mean_trajectory = mean(particles; dims=2)

# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(particles; label=false, opacity=0.01, color=:black, xlabel="t", ylabel="state")
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)

# The update rate is now much higher throughout time.
update_rate = sum(abs.(diff(particles; dims=2)) .> 0; dims=2) / Nₛ
plot(
update_rate;
label=false,
ylim=[0, 1],
legend=:bottomleft,
xlabel="Iteration",
ylabel="Update rate",
)
hline!([1 - 1 / Nₚ]; label="N: $(Nₚ)")

0 comments on commit 72a1e55

Please sign in to comment.