-
Notifications
You must be signed in to change notification settings - Fork 9
/
script.jl
164 lines (142 loc) · 5.43 KB
/
script.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# # Particle Gibbs for non-linear models
using AdvancedPS
using Random
using Distributions
using Plots
using AbstractMCMC
using Random123
using SSMProblems
"""
    plot_update_rate(update_rate, N)

Plot empirical update rate against theoretical value

The theoretical optimum for `N` particles is `1 - 1/N`, drawn as a
horizontal reference line.
"""
function plot_update_rate(update_rate::AbstractVector{Float64}, Nₚ::Int)
    fig = plot(
        update_rate;
        label=false,
        ylim=[0, 1],
        legend=:bottomleft,
        xlabel="Iteration",
        ylabel="Update rate",
    )
    # hline! mutates `fig` in place and returns it.
    hline!(fig, [1 - 1 / Nₚ]; label="N: $(Nₚ)")
    return fig
end
"""
    update_rate(trajectories, N)

Compute latent state update rate

For each timestep (row), count how often the state changed between
consecutive sampled trajectories (columns) and divide by `N`.
Returns a column matrix of per-timestep rates.
"""
function update_rate(particles::AbstractMatrix{Float64}, Nₛ)
    # A nonzero difference between neighbouring columns means the state moved.
    changed = abs.(diff(particles; dims=2)) .> 0
    return sum(changed; dims=2) / Nₛ
end
#md nothing #hide
# We consider the following stochastic volatility model:
#
# ```math
# x_{t+1} = a x_t + v_t \quad v_{t} \sim \mathcal{N}(0, q^2)
# ```
# ```math
# y_{t} = e_t \exp(\frac{1}{2}x_t) \quad e_t \sim \mathcal{N}(0, 1)
# ```
#
# We can reformulate the above in terms of transition and observation densities:
# ```math
# x_{t+1} \sim f_{\theta}(x_{t+1}|x_t) = \mathcal{N}(a x_t, q^2)
# ```
# ```math
# y_t \sim g_{\theta}(y_t|x_t) = \mathcal{N}(0, \exp(\frac{1}{2}x_t)^2)
# ```
# with the initial distribution $f_0(x) = \mathcal{N}(0, q^2)$.
# Here we assume the static parameters $\theta = (a, q)$ are known and we are only interested in sampling from the latent state $x_t$.
# Static model parameters: transition coefficient `a`, noise standard
# deviation `q`, and number of timesteps `T`.
Parameters = @NamedTuple{a::Float64, q::Float64, T::Int}
# Transition density f(x_{t+1} | x_t): Normal with mean `a * x_t` and std `q`.
function f(θ::Parameters, state, t)
    return Normal(θ.a * state, θ.q)
end

# Observation density g(y_t | x_t): zero-mean Normal with std `exp(x_t / 2)`.
function g(θ::Parameters, state, t)
    return Normal(0, exp(0.5 * state))
end

# Initial state distribution f₀: zero-mean Normal with std `q`.
function f₀(θ::Parameters)
    return Normal(0, θ.q)
end
#md nothing #hide
# Let's simulate some data
a = 0.9 # AR(1) transition coefficient (multiplies the previous state in `f`)
q = 0.5 # Transition / initial noise standard deviation (used in `f` and `f₀`)
Tₘ = 200 # Number of observations
Nₚ = 20 # Number of particles
Nₛ = 200 # Number of samples
seed = 1 # Reproduce everything
θ₀ = Parameters((a, q, Tₘ))
rng = Random.MersenneTwister(seed)
x = zeros(Tₘ) # latent states
y = zeros(Tₘ) # observations
x[1] = 0
# Simulate forward: draw the next state from `f`, then the observation from `g`.
for t in 1:Tₘ
    if t < Tₘ
        x[t + 1] = rand(rng, f(θ₀, x[t], t))
    end
    y[t] = rand(rng, g(θ₀, x[t], t))
end
# Here are the latent and observation series:
plot(x; label="x", xlabel="t")
#
plot(y; label="y", xlabel="t")
# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
mutable struct NonLinearTimeSeries <: SSMProblems.AbstractStateSpaceModel
    X::Vector{Float64}            # sampled latent trajectory, filled during inference
    observations::Vector{Float64} # observed series y
    θ::Parameters                 # static parameters (a, q, T)

    # No-observation constructor.
    # Fixed: the original called `new(Float64[], θ)`, passing `θ` as the
    # `observations::Vector{Float64}` field, which errors on conversion.
    NonLinearTimeSeries(θ::Parameters) = new(Float64[], Float64[], θ)
    NonLinearTimeSeries(y::Vector{Float64}, θ::Parameters) = new(Float64[], y, θ)
end
# The dynamics of the model is defined through the `AbstractStateSpaceModel` interface:
# Initial transition: draw the first latent state from f₀.
function SSMProblems.transition!!(rng::AbstractRNG, model::NonLinearTimeSeries)
    init_dist = f₀(model.θ)
    return rand(rng, init_dist)
end
# Step transition: propagate the latent state one step through `f`.
function SSMProblems.transition!!(
    rng::AbstractRNG, model::NonLinearTimeSeries, state::Float64, step::Int
)
    step_dist = f(model.θ, state, step)
    return rand(rng, step_dist)
end
# Log-likelihood of the observation at `step` under the emission density `g`.
# Fixed: the parameter was misspelled `modeL` while the body referenced `model`,
# so every call raised `UndefVarError: model` (or hit a stray global).
function SSMProblems.emission_logdensity(
    model::NonLinearTimeSeries, state::Float64, step::Int
)
    return logpdf(g(model.θ, state, step), model.observations[step])
end
# Log-density of moving from `prev_state` to `current_state` under `f`.
function SSMProblems.transition_logdensity(
    model::NonLinearTimeSeries, prev_state, current_state, step
)
    transition_dist = f(model.θ, prev_state, step)
    return logpdf(transition_dist, current_state)
end
# We need to tell AdvancedPS when to stop the execution of the model.
# Use the length stored in the model's own parameters instead of the
# non-const global `Tₘ` the original (TODO-marked) version depended on —
# same value here, but self-contained and type-stable.
AdvancedPS.isdone(model::NonLinearTimeSeries, step) = step > model.θ.T
# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(y, θ₀)
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide
# Stack all sampled trajectories into a Tₘ × Nₛ matrix.
# `reduce(hcat, ...)` avoids splatting Nₛ vectors into `hcat(...)`,
# which is slow for a large, unknown number of arguments.
particles = reduce(hcat, [chain.trajectory.model.X for chain in chains])
mean_trajectory = mean(particles; dims=2);
#md nothing #hide
# We can now plot all the generated traces.
# Beyond the last few timesteps all the trajectories collapse into one. Using the ancestor updating step can help with the degeneracy problem, as we show below.
# Only the first 50 trajectories are shown to keep the plot readable.
scatter(
    particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
# We can also check the mixing as defined in the Gaussian State Space model example. As seen on the
# scatter plot above, we are mostly left with a single trajectory before timestep 150. The orange
# bar is the optimal mixing rate for the number of particles we use.
plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ)
# Let's see if ancestor sampling can help with the degeneracy problem. We use the same number of particles, but replace the sampler with PGAS.
pgas = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false);
# Stack trajectories with `reduce(hcat, ...)` rather than splatting Nₛ vectors
# into `hcat(...)` (varargs splatting is slow for many arguments).
particles = reduce(hcat, [chain.trajectory.model.X for chain in chains]);
mean_trajectory = mean(particles; dims=2);
# The ancestor sampling has helped with the degeneracy problem and we now have a much more diverse set of trajectories, also at earlier time periods.
scatter(
    particles[:, 1:50]; label=false, opacity=0.5, color=:black, xlabel="t", ylabel="state"
)
plot!(x; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
# The update rate is now much higher throughout time.
plot_update_rate(update_rate(particles, Nₛ)[:, 1], Nₚ)