Skip to content

Commit

Permalink
Merge 804bda0 into 1e5dfdd
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Mar 25, 2024
2 parents 1e5dfdd + 804bda0 commit 31dd078
Show file tree
Hide file tree
Showing 18 changed files with 394 additions and 89 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
Expand All @@ -21,10 +22,10 @@ AdvancedPSLibtaskExt = "Libtask"
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Random = "1.6"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
Random = "1.6"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions examples/gaussian-process/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
35 changes: 23 additions & 12 deletions examples/gaussian-process/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@ using AbstractGPs
using Plots
using Distributions
using Libtask
using SSMProblems

Parameters = @NamedTuple begin
a::Float64
q::Float64
kernel
end

mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters

GPSSM(params::Parameters) = new(Vector{Float64}(), params)
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
end

seed = 1
Expand All @@ -29,21 +32,20 @@ q = 0.5

params = Parameters((a, q, SqExponentialKernel()))

f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q)
h(model::GPSSM) = Normal(0, model.θ.q)
g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2)
f(θ::Parameters, x, t) = Normal.a * x, θ.q)
h(θ::Parameters) = Normal(0, θ.q)
g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)

rng = Random.MersenneTwister(seed)
ref_model = GPSSM(params)

x = zeros(T)
y = similar(x)
x[1] = rand(rng, h(ref_model))
x[1] = rand(rng, h(params))
for t in 1:T
if t < T
x[t + 1] = rand(rng, f(ref_model, x[t], t))
x[t + 1] = rand(rng, f(params, x[t], t))
end
y[t] = rand(rng, g(ref_model, x[t], t))
y[t] = rand(rng, g(params, x[t], t))
end

function gp_update(model::GPSSM, state, step)
Expand All @@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step)
return Normal(μ[1], σ[1])
end

AdvancedPS.initialization(::GPSSM) = h(model)
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
return rand(rng, gp_update(model, state, step))
end

function SSMProblems.emission_logdensity(model::GPSSM, state, step)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
return logpdf(gp_update(model, prev_state, step), current_state)
end

AdvancedPS.isdone(::GPSSM, step) = step > T

model = GPSSM(params)
model = GPSSM(y, params)
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pg, Nₛ)

Expand Down
1 change: 1 addition & 0 deletions examples/gaussian-ssm/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
42 changes: 27 additions & 15 deletions examples/gaussian-ssm/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using AdvancedPS
using Random
using Distributions
using Plots
using SSMProblems

# We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk
# and the observation is linear in the latent states, namely:
Expand Down Expand Up @@ -33,32 +34,45 @@ Parameters = @NamedTuple begin
r::Float64
end

mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters
LinearSSM::Parameters) = new(Vector{Float64}(), θ)
LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ)
end

# and the densities defined above.
f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
f(θ::Parameters, state, t) = Normal.a * state, θ.q) # Transition density
g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density
f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density
#md nothing #hide

# We also need to specify the dynamics of the system through the transition equations:
# - `AdvancedPS.initialization`: the initial state density
# - `AdvancedPS.transition`: the state transition density
# - `AdvancedPS.observation`: the observation score given the observed data
# - `AdvancedPS.isdone`: signals the end of the execution for the model
AdvancedPS.initialization(model::LinearSSM) = f₀(model)
AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step)
function AdvancedPS.observation(model::LinearSSM, state, step)
return logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ))
function SSMProblems.transition!!(
rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int
)
return rand(rng, f(model.θ, state, step))
end

function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(
model::LinearSSM, prev_state, current_state, step
)
return logpdf(f(model.θ, prev_state, step), current_state)
end

# We need to think seriously about how the data is handled
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ

# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Expand All @@ -72,14 +86,12 @@ rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
y = zeros(Tₘ)

reference = LinearSSM(θ₀)
x[1] = rand(rng, f₀(reference))
x[1] = rand(rng, f₀(θ₀))
for t in 1:Tₘ
if t < Tₘ
x[t + 1] = rand(rng, f(reference, x[t], t))
x[t + 1] = rand(rng, f(θ₀, x[t], t))
end
y[t] = rand(rng, g(reference, x[t], t))
y[t] = rand(rng, g(θ₀, x[t], t))
end

# Here are the latent and obseravation timeseries
Expand All @@ -88,7 +100,7 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)

# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
# and a model interface.
model = LinearSSM(θ₀)
model = LinearSSM(y, θ₀)
pgas = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false);
#md nothing #hide
Expand Down
7 changes: 7 additions & 0 deletions examples/levy-ssm/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
Loading

0 comments on commit 31dd078

Please sign in to comment.