From 364380285aa48ad18aebdc65b6051a5aaec56bce Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 24 Mar 2024 19:27:26 +0000 Subject: [PATCH 1/5] SSMProblems - Draft --- Project.toml | 3 +- examples/gaussian-ssm/Project.toml | 1 + examples/gaussian-ssm/script.jl | 42 +++++++++++++++--------- examples/particle-gibbs/Project.toml | 1 + examples/particle-gibbs/script.jl | 32 +++++++++++++++---- src/AdvancedPS.jl | 1 + src/container.jl | 2 +- src/model.jl | 9 +++++- src/pgas.jl | 48 ++++++---------------------- src/smc.jl | 11 +++++-- 10 files changed, 84 insertions(+), 66 deletions(-) diff --git a/Project.toml b/Project.toml index 4524aa40..384ecb5d 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [weakdeps] @@ -21,10 +22,10 @@ AdvancedPSLibtaskExt = "Libtask" AbstractMCMC = "2, 3, 4, 5" Distributions = "0.23, 0.24, 0.25" Libtask = "0.8" +Random = "1.6" Random123 = "1.3" Requires = "1.0" StatsFuns = "0.9, 1" -Random = "1.6" julia = "1.6" [extras] diff --git a/examples/gaussian-ssm/Project.toml b/examples/gaussian-ssm/Project.toml index 1093db09..e87cc369 100644 --- a/examples/gaussian-ssm/Project.toml +++ b/examples/gaussian-ssm/Project.toml @@ -5,4 +5,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/examples/gaussian-ssm/script.jl b/examples/gaussian-ssm/script.jl index debd9867..d7a5889c 100644 --- a/examples/gaussian-ssm/script.jl +++ b/examples/gaussian-ssm/script.jl @@ -3,6 +3,7 @@ using AdvancedPS using Random using Distributions using Plots +using SSMProblems # We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk # and the observation is linear in the latent states, namely: @@ -33,16 +34,18 @@ Parameters = @NamedTuple begin r::Float64 end -mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel +mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::Parameters LinearSSM(θ::Parameters) = new(Vector{Float64}(), θ) + LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ) end # and the densities defined above. -f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density -g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density -f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density +f(θ::Parameters, state, t) = Normal(θ.a * state, θ.q) # Transition density +g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density +f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density #md nothing #hide # We also need to specify the dynamics of the system through the transition equations: @@ -50,15 +53,26 @@ f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state den # - `AdvancedPS.transition`: the state transition density # - `AdvancedPS.observation`: the observation score given the observed data # - `AdvancedPS.isdone`: signals the end of the execution for the model -AdvancedPS.initialization(model::LinearSSM) = f₀(model) -AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step) -function AdvancedPS.observation(model::LinearSSM, state, step) - return logpdf(g(model, state, step), y[step]) +SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ)) +function SSMProblems.transition!!( + rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int +) + return rand(rng, f(model.θ, state, step)) +end + +function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int) + return logpdf(g(model.θ, state, step), model.observations[step]) +end +function SSMProblems.transition_logdensity( + model::LinearSSM, prev_state, current_state, step +) + return logpdf(f(model.θ, prev_state, step), current_state) end + +# We need to think seriously about how the data is handled AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ # Everything is now ready to simulate some data. - a = 0.9 # Scale q = 0.32 # State variance r = 1 # Observation variance @@ -72,14 +86,12 @@ rng = Random.MersenneTwister(seed) x = zeros(Tₘ) y = zeros(Tₘ) - -reference = LinearSSM(θ₀) -x[1] = rand(rng, f₀(reference)) +x[1] = rand(rng, f₀(θ₀)) for t in 1:Tₘ if t < Tₘ - x[t + 1] = rand(rng, f(reference, x[t], t)) + x[t + 1] = rand(rng, f(θ₀, x[t], t)) end - y[t] = rand(rng, g(reference, x[t], t)) + y[t] = rand(rng, g(θ₀, x[t], t)) end # Here are the latent and obseravation timeseries @@ -88,7 +100,7 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5) # `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel # and a model interface. -model = LinearSSM(θ₀) +model = LinearSSM(y, θ₀) pgas = AdvancedPS.PGAS(Nₚ) chains = sample(rng, model, pgas, Nₛ; progress=false); #md nothing #hide diff --git a/examples/particle-gibbs/Project.toml b/examples/particle-gibbs/Project.toml index d8ae6a53..6fa1fbb2 100644 --- a/examples/particle-gibbs/Project.toml +++ b/examples/particle-gibbs/Project.toml @@ -7,4 +7,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" diff --git a/examples/particle-gibbs/script.jl b/examples/particle-gibbs/script.jl index baa6ceb1..f7a9d2aa 100644 --- a/examples/particle-gibbs/script.jl +++ b/examples/particle-gibbs/script.jl @@ -5,6 +5,7 @@ using Distributions using Plots using AbstractMCMC using Random123 +using SSMProblems """ plot_update_rate(update_rate, N) @@ -90,22 +91,41 @@ plot(x; label="x", xlabel="t") plot(y; label="y", xlabel="t") # Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition: -mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel +mutable struct NonLinearTimeSeries <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::Parameters NonLinearTimeSeries(θ::Parameters) = new(Float64[], θ) + NonLinearTimeSeries(y::Vector{Float64}, θ::Parameters) = new(Float64[], y, θ) end # The dynamics of the model is defined through the `AbstractStateSpaceModel` interface: -AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ) -AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step) -function AdvancedPS.observation(model::NonLinearTimeSeries, state, step) - return logpdf(g(model.θ, state, step), y[step]) +function SSMProblems.transition!!(rng::AbstractRNG, model::NonLinearTimeSeries) + return rand(rng, f₀(model.θ)) end +function SSMProblems.transition!!( + rng::AbstractRNG, model::NonLinearTimeSeries, state::Float64, step::Int +) + return rand(rng, f(model.θ, state, step)) +end + +function SSMProblems.emission_logdensity( + modeL::NonLinearTimeSeries, state::Float64, step::Int +) + return logpdf(g(model.θ, state, step), model.observations[step]) +end +function SSMProblems.transition_logdensity( + model::NonLinearTimeSeries, prev_state, current_state, step +) + return logpdf(f(model.θ, prev_state, step), current_state) +end + +# We need to tell AdvancedPS when to stop the execution of the model +# TODO AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ # Here we use the particle gibbs kernel without adaptive resampling. -model = NonLinearTimeSeries(θ₀) +model = NonLinearTimeSeries(y, θ₀) pg = AdvancedPS.PG(Nₚ, 1.0) chains = sample(rng, model, pg, Nₛ; progress=false); #md nothing #hide diff --git a/src/AdvancedPS.jl b/src/AdvancedPS.jl index 8ce7c2b7..faa673e8 100644 --- a/src/AdvancedPS.jl +++ b/src/AdvancedPS.jl @@ -5,6 +5,7 @@ using Distributions: Distributions using Random: Random using StatsFuns: StatsFuns using Random123: Random123 +using SSMProblems: SSMProblems abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end diff --git a/src/container.jl b/src/container.jl index c5d693d5..3ac64fef 100644 --- a/src/container.jl +++ b/src/container.jl @@ -269,7 +269,7 @@ function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing) # Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and # ``θᵢ`` are variables of other samplers. isref = p === ref - score = advance!(p, isref) + score = advance!(p, isref) # SSMProblems.transition!! if score === nothing numdone += 1 diff --git a/src/model.jl b/src/model.jl index fab9a34f..b7b72db3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -7,13 +7,20 @@ mutable struct Trace{F,R} end const Particle = Trace -const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R} +const SSMTrace{R} = Trace{<:SSMProblems.AbstractStateSpaceModel,R} const GenericTrace{R} = Trace{<:AbstractGenericModel,R} reset_logprob!(::AdvancedPS.Particle) = nothing reset_model(f) = deepcopy(f) delete_retained!(f) = nothing +""" + isdone(model::SSMProblems.AbstractStateSpaceModel, step) + +Returns `true` if we reached the end of the model execution +""" +function isdone end + """ copy(trace::Trace) diff --git a/src/pgas.jl b/src/pgas.jl index 8dfb449e..41bcfea7 100644 --- a/src/pgas.jl +++ b/src/pgas.jl @@ -1,33 +1,3 @@ -""" - initialization(model::AbstractStateSpaceModel) - -Define the distribution of the initial state of the State Space Model -""" -function initialization end - -""" - transition(model::AbstractStateSpaceModel, state, step) - -Define the transition density of the State Space Model -Must return `nothing` if it consumed all the data -""" -function transition end - -""" - observation(model::AbstractStateSpaceModel, state, step) - -Return the log-likelihood of the observed measurement conditional on the current state of the model. -Must return `nothing` if it consumed all the data -""" -function observation end - -""" - isdone(model::AbstractStateSpaceModel, step) - -Return `true` if model reached final state else `false` -""" -function isdone end - """ previous_state(trace::SSMTrace) @@ -54,13 +24,11 @@ current_step(trace::SSMTrace) = trace.rng.count Get the log weight of the transition from previous state of `model` to `x` """ function transition_logweight(particle::SSMTrace, x) - score = Distributions.logpdf( - transition( - particle.model, - particle.model.X[current_step(particle) - 2], - current_step(particle) - 2, - ), + score = SSMProblems.transition_logdensity( + particle.model, + particle.model.X[current_step(particle) - 2], x, + current_step(particle) - 1, ) return score end @@ -93,16 +61,18 @@ function advance!(particle::SSMTrace, isref::Bool=false) if !isref if running_step == 1 - new_state = rand(particle.rng, initialization(model)) # Generate initial state, maybe fallback to 0 if initialization is not defined + new_state = SSMProblems.transition!!(particle.rng, model) else current_state = model.X[running_step - 1] - new_state = rand(particle.rng, transition(model, current_state, running_step)) + new_state = SSMProblems.transition!!( + particle.rng, model, current_state, running_step + ) end else new_state = model.X[running_step] # We need the current state from the reference particle end - score = observation(model, new_state, running_step) + score = SSMProblems.emission_logdensity(model, new_state, running_step) # accept transition !isref && push!(model.X, new_state) diff --git a/src/smc.jl b/src/smc.jl index c45b127a..6d982d2b 100644 --- a/src/smc.jl +++ b/src/smc.jl @@ -26,12 +26,17 @@ struct SMCSample{P,W,L} logevidence::L end -function AbstractMCMC.sample(model::AbstractStateSpaceModel, sampler::SMC; kwargs...) +function AbstractMCMC.sample( + model::SSMProblems.AbstractStateSpaceModel, sampler::SMC; kwargs... +) return AbstractMCMC.sample(Random.GLOBAL_RNG, model, sampler; kwargs...) end function AbstractMCMC.sample( - rng::Random.AbstractRNG, model::AbstractStateSpaceModel, sampler::SMC; kwargs... + rng::Random.AbstractRNG, + model::SSMProblems.AbstractStateSpaceModel, + sampler::SMC; + kwargs..., ) if !isempty(kwargs) @warn "keyword arguments $(keys(kwargs)) are not supported by `SMC`" @@ -95,7 +100,7 @@ PGAS(nparticles::Int) = PGAS(nparticles, ResampleWithESSThreshold(1.0)) function AbstractMCMC.step( rng::Random.AbstractRNG, - model::AbstractStateSpaceModel, + model::SSMProblems.AbstractStateSpaceModel, sampler::Union{PGAS,PG}, state::Union{PGState,Nothing}=nothing; kwargs..., From 224993342e050ddf4d32e29341235baeb88df36b Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 24 Mar 2024 19:39:31 +0000 Subject: [PATCH 2/5] Tests --- test/Project.toml | 1 + test/container.jl | 11 +++++++---- test/pgas.jl | 13 +++++-------- test/runtests.jl | 1 + 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 9fcb69de..686826f8 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Random123 = "74087812-796a-5b5d-8853-05524746bad3" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/container.jl b/test/container.jl index 7a312291..ba266499 100644 --- a/test/container.jl +++ b/test/container.jl @@ -1,14 +1,17 @@ @testset "container.jl" begin # Since the extension would hide the low level function call API - mutable struct LogPModel{T} <: AdvancedPS.AbstractStateSpaceModel + mutable struct LogPModel{T} <: SSMProblems.AbstractStateSpaceModel logp::T X::Array{T} end - AdvancedPS.initialization(model::LogPModel) = Uniform() - AdvancedPS.transition(model::LogPModel, state, step) = Uniform() - AdvancedPS.observation(model::LogPModel, state, step) = model.logp + SSMProblems.transition!!(rng::AbstractRNG, model::LogPModel) = rand(rng, Uniform()) + function SSMProblems.transition!!(rng::AbstractRNG, model::LogPModel, state, step) + return rand(rng, Uniform()) + end + SSMProblems.emission_logdensity(model::LogPModel, state, step) = model.logp + AdvancedPS.isdone(model::LogPModel, step) = false @testset "copy particle container" begin diff --git a/test/pgas.jl b/test/pgas.jl index 8c7b3160..7d91350e 100644 --- a/test/pgas.jl +++ b/test/pgas.jl @@ -5,19 +5,16 @@ r::Float64 end - mutable struct BaseModel <: AdvancedPS.AbstractStateSpaceModel + mutable struct BaseModel <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} θ::Params BaseModel(params::Params) = new(Vector{Float64}(), params) end - AdvancedPS.initialization(model::BaseModel) = Normal(0, model.θ.q) - function AdvancedPS.transition(model::BaseModel, state, step) - return Distributions.Normal(model.θ.a * state, model.θ.q) - end - function AdvancedPS.observation(model::BaseModel, state, step) - return Distributions.logpdf(Distributions.Normal(state, model.θ.r), 0) - end + SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel) = rand(rng, Normal(0, model.θ.q)) + SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel, state, step) = rand(rng, Normal(model.θ.a * state, model.θ.q)) + SSMProblems.emission_logdensity(model::BaseModel, state, step) = logpdf(Distributions.Normal(state, model.θ.r), 0) + SSMProblems.transition_logdensity(model::BaseModel, prev_state, current_state, step) = logpdf(Normal(model.θ.a * prev_state, model.θ.q), current_state) AdvancedPS.isdone(::BaseModel, step) = step > 3 diff --git a/test/runtests.jl b/test/runtests.jl index d7058062..b2c0990c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Distributions using Libtask using Random using Test +using SSMProblems @testset "AdvancedPS.jl" begin @testset "Resampling tests" begin From fb5c8ad809e610216e82a73603830a65764cae3e Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Sun, 24 Mar 2024 19:59:31 +0000 Subject: [PATCH 3/5] Format, GP-SSM --- examples/gaussian-process/Project.toml | 1 + examples/gaussian-process/script.jl | 35 +++++++++++++++++--------- test/pgas.jl | 20 +++++++++++---- 3 files changed, 39 insertions(+), 17 deletions(-) diff --git a/examples/gaussian-process/Project.toml b/examples/gaussian-process/Project.toml index e2796748..9a45895d 100644 --- a/examples/gaussian-process/Project.toml +++ b/examples/gaussian-process/Project.toml @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/examples/gaussian-process/script.jl b/examples/gaussian-process/script.jl index d70a9cda..1a58bde3 100644 --- a/examples/gaussian-process/script.jl +++ b/examples/gaussian-process/script.jl @@ -6,6 +6,7 @@ using AbstractGPs using Plots using Distributions using Libtask +using SSMProblems Parameters = @NamedTuple begin a::Float64 @@ -13,11 +14,13 @@ Parameters = @NamedTuple begin kernel end -mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel +mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel X::Vector{Float64} + observations::Vector{Float64} θ::Parameters GPSSM(params::Parameters) = new(Vector{Float64}(), params) + GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params) end seed = 1 @@ -29,21 +32,20 @@ q = 0.5 params = Parameters((a, q, SqExponentialKernel())) -f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q) -h(model::GPSSM) = Normal(0, model.θ.q) -g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2) +f(θ::Parameters, x, t) = Normal(θ.a * x, θ.q) +h(θ::Parameters) = Normal(0, θ.q) +g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2) rng = Random.MersenneTwister(seed) -ref_model = GPSSM(params) x = zeros(T) y = similar(x) -x[1] = rand(rng, h(ref_model)) +x[1] = rand(rng, h(params)) for t in 1:T if t < T - x[t + 1] = rand(rng, f(ref_model, x[t], t)) + x[t + 1] = rand(rng, f(params, x[t], t)) end - y[t] = rand(rng, g(ref_model, x[t], t)) + y[t] = rand(rng, g(params, x[t], t)) end function gp_update(model::GPSSM, state, step) @@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step) return Normal(μ[1], σ[1]) end -AdvancedPS.initialization(::GPSSM) = h(model) -AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step) -AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step]) +SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ)) +function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step) + return rand(rng, gp_update(model, state, step)) +end + +function SSMProblems.emission_logdensity(model::GPSSM, state, step) + return logpdf(g(model.θ, state, step), model.observations[step]) +end +function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step) + return logpdf(gp_update(model, prev_state, step), current_state) +end + AdvancedPS.isdone(::GPSSM, step) = step > T -model = GPSSM(params) +model = GPSSM(y, params) pg = AdvancedPS.PGAS(Nₚ) chains = sample(rng, model, pg, Nₛ) diff --git a/test/pgas.jl b/test/pgas.jl index 7d91350e..3a701db7 100644 --- a/test/pgas.jl +++ b/test/pgas.jl @@ -11,10 +11,20 @@ BaseModel(params::Params) = new(Vector{Float64}(), params) end - SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel) = rand(rng, Normal(0, model.θ.q)) - SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel, state, step) = rand(rng, Normal(model.θ.a * state, model.θ.q)) - SSMProblems.emission_logdensity(model::BaseModel, state, step) = logpdf(Distributions.Normal(state, model.θ.r), 0) - SSMProblems.transition_logdensity(model::BaseModel, prev_state, current_state, step) = logpdf(Normal(model.θ.a * prev_state, model.θ.q), current_state) + function SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel) + return rand(rng, Normal(0, model.θ.q)) + end + function SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel, state, step) + return rand(rng, Normal(model.θ.a * state, model.θ.q)) + end + function SSMProblems.emission_logdensity(model::BaseModel, state, step) + return logpdf(Distributions.Normal(state, model.θ.r), 0) + end + function SSMProblems.transition_logdensity( + model::BaseModel, prev_state, current_state, step + ) + return logpdf(Normal(model.θ.a * prev_state, model.θ.q), current_state) + end AdvancedPS.isdone(::BaseModel, step) = step > 3 @@ -80,7 +90,7 @@ seed = 10 rng = Random.MersenneTwister(seed) - for sampler in [AdvancedPS.PGAS(10)] + for sampler in [AdvancedPS.PGAS(10), AdvancedPS.PG(10)] Random.seed!(rng, seed) chain1 = sample(rng, model, sampler, 10) vals1 = hcat([chain.trajectory.model.X for chain in chain1]...) From aef19fae74c5a93ec60ddf7dd8d37254a8a782a6 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Mon, 25 Mar 2024 19:15:45 +0000 Subject: [PATCH 4/5] Add levy-ssm --- examples/levy-ssm/Project.toml | 7 + examples/levy-ssm/script.jl | 255 +++++++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+) create mode 100644 examples/levy-ssm/Project.toml create mode 100644 examples/levy-ssm/script.jl diff --git a/examples/levy-ssm/Project.toml b/examples/levy-ssm/Project.toml new file mode 100644 index 00000000..572ec481 --- /dev/null +++ b/examples/levy-ssm/Project.toml @@ -0,0 +1,7 @@ +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +AdvancedPS = "576499cb-2369-40b2-a588-c64705576edc" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/examples/levy-ssm/script.jl b/examples/levy-ssm/script.jl new file mode 100644 index 00000000..f7e9b62d --- /dev/null +++ b/examples/levy-ssm/script.jl @@ -0,0 +1,255 @@ +# # Levy-SSM latent state inference +using AdvancedPS: SSMProblems +using AdvancedPS +using Random +using Plots +using Distributions +using AdvancedPS +using LinearAlgebra +using SSMProblems + +struct GammaProcess + C::Float64 + β::Float64 + tol::Float64 +end + +struct GammaPath{T} + jumps::Vector{T} + times::Vector{T} +end + +struct LangevinDynamics{T} + A::Matrix{T} + L::Vector{T} + θ::T + H::Vector{T} + σe::T +end + +struct NormalMeanVariance{T} + μ::T + σ::T +end + +function simulate( + rng::AbstractRNG, + process::GammaProcess, + rate::Float64, + start::Float64, + finish::Float64, + t0::Float64=0.0, +) + let β = process.β, C = process.C, tolerance = process.tol + jumps = Float64[] + last_jump = Inf + t = t0 + truncated = last_jump < tolerance + while !truncated + t += rand(rng, Exponential(1.0 / rate)) + xi = 1.0 / (β * (exp(t / C) - 1)) + prob = (1.0 + β * xi) * exp(-β * xi) + if rand(rng) < prob + push!(jumps, xi) + last_jump = xi + end + truncated = last_jump < tolerance + end + times = rand(rng, Uniform(start, finish), length(jumps)) + return GammaPath(jumps, times) + end +end + +function integral(times::Array{Float64}, path::GammaPath) + let jumps = path.jumps, jump_times = path.times + return [sum(jumps[jump_times .<= t]) for t in times] + end +end + +# Gamma Process +C = 1.0 +β = 1.0 +ϵ = 1e-10 +process = GammaProcess(C, β, ϵ) + +# Normal Mean-Variance representation +μw = 0.0 +σw = 1.0 +nvm = NormalMeanVariance(μw, σw) + +# Levy SSM with Langevin dynamics +# dx(t) = A x(t) dt + L dW(t) +# y(t) = H x(t) + ϵ(t) +θ = -0.5 +A = [ + 0.0 1.0 + 0.0 θ +] +L = [0.0; 1.0] +σe = 1.0 +H = [1.0, 0] +dyn = LangevinDynamics(A, L, θ, H, σe) + +# Simulation parameters +start, finish = 0, 100 +N = 200 +ts = range(start, finish; length=N) +seed = 4 +rng = Random.MersenneTwister(seed) +Np = 10 +Ns = 10 + +f(dt, θ) = exp(θ * dt) +function Base.exp(dyn::LangevinDynamics, dt::Real) + let θ = dyn.θ + f_val = f(dt, θ) + return [1.0 (f_val - 1)/θ; 0 f_val] + end +end + +function meancov( + t::T, dyn::LangevinDynamics, path::GammaPath, nvm::NormalMeanVariance +) where {T<:Real} + μ = zeros(T, 2) + Σ = zeros(T, (2, 2)) + let times = path.times, jumps = path.jumps, μw = nvm.μ, σw = nvm.σ + for (v, z) in zip(times, jumps) + ft = exp(dyn, (t - v)) * dyn.L + μ += ft .* μw .* z + Σ += ft * transpose(ft) .* σw^2 .* z + end + return μ, Σ + end +end + +X = zeros(Float64, (N, 2)) +Y = zeros(Float64, N) +for (i, t) in enumerate(ts) + if i > 1 + s = ts[i - 1] + dt = t - s + path = simulate(rng, process, dt, s, t, ϵ) + μ, Σ = meancov(t, dyn, path, nvm) + X[i, :] .= rand(rng, MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ)) + end + + let H = dyn.H, σe = dyn.σe + Y[i] = transpose(H) * X[i, :] + rand(rng, Normal(0, σe)) + end +end + +# AdvancedPS +Parameters = @NamedTuple begin + dyn::LangevinDynamics + process::GammaProcess + nvm::NormalMeanVariance + times::Vector{Float64} +end + +struct MixedState{T} + x::Vector{T} + path::GammaPath{T} +end + +mutable struct LevyLangevin <: SSMProblems.AbstractStateSpaceModel + X::Vector{MixedState{Float64}} + observations::Vector{Float64} + θ::Parameters + LevyLangevin(θ::Parameters) = new(Vector{MixedState{Float64}}(), θ) + function LevyLangevin(y::Vector{Float64}, θ::Parameters) + return new(Vector{MixedState{Float64}}(), y, θ) + end +end + +function SSMProblems.transition!!(rng::AbstractRNG, model::LevyLangevin) + return MixedState( + rand(rng, MultivariateNormal([0, 0], I)), GammaPath(Float64[], Float64[]) + ) +end + +function SSMProblems.transition!!( + rng::AbstractRNG, model::LevyLangevin, state::MixedState, step +) + times = model.θ.times + s = times[step - 1] + t = times[step] + dt = t - s + path = simulate(rng, model.θ.process, dt, s, t) + μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) + Σ += 1e-6 * I + return MixedState(rand(rng, MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ)), path) +end + +function SSMProblems.transition_logdensity( + model::LevyLangevin, prev_state::MixedState, current_state::MixedState, step +) + times = model.θ.times + s = times[step - 1] + t = times[step] + dt = t - s + path = simulate(rng, model.θ.process, dt, s, t) + μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm) + Σ += 1e-6 * I + return logpdf(MultivariateNormal(exp(dyn, dt) * prev_state.x + μ, Σ), current_state.x) +end + +function SSMProblems.emission_logdensity(model::LevyLangevin, state::MixedState, step) + return logpdf(Normal(transpose(H) * state.x, σe), model.observations[step]) +end + +AdvancedPS.isdone(model::LevyLangevin, step) = step > length(model.θ.times) + +θ₀ = Parameters((dyn, process, nvm, ts)) +model = LevyLangevin(Y, θ₀) +pg = AdvancedPS.PGAS(Np) +chains = sample(rng, model, pg, Ns; progress=false); + +# Concat all sampled states +particles = hcat([chain.trajectory.model.X for chain in chains]...) +marginal_states = map(s -> s.x, particles); +jump_times = map(s -> s.path.times, particles); +jump_intensities = map(s -> s.path.jumps, particles); + +# Plot marginal state and jump intensities for one trajectory +p1 = plot( + ts, + [state[1] for state in marginal_states[:, end]]; + color=:darkorange, + label="Marginal State (x1)", +) +plot!( + ts, + [state[2] for state in marginal_states[:, end]]; + color=:dodgerblue, + label="Marginal State (x2)", +) + +p2 = scatter( + vcat([t for t in jump_times[:, end]]...), + vcat([j for j in jump_intensities[:, end]]...); + color=:darkorange, + label="Jumps", +) + +plot( + p1, p2; plot_title="Marginal State and Jump Intensities", layout=(2, 1), size=(600, 600) +) + +# Plot mean trajectory with standard deviation +mean_trajectory = transpose(hcat(mean(marginal_states; dims=2)...)) +std_trajectory = dropdims(std(stack(marginal_states); dims=3); dims=3) + +ps = [] +for d in 1:2 + p = plot( + mean_trajectory[:, d]; + ribbon=2 * std_trajectory[:, d]', + color=:darkorange, + label="Mean Trajectory (±2σ)", + fillalpha=0.2, + title="Marginal State Trajectories (X$d)", + ) + plot!(p, X[:, d]; color=:dodgerblue, label="True Trajectory") + push!(ps, p) +end +plot(ps...; layout=(2, 1), size=(600, 600)) From 804bda040a1447d789e2f8510a79320550053ce2 Mon Sep 17 00:00:00 2001 From: FredericWantiez Date: Mon, 25 Mar 2024 19:36:52 +0000 Subject: [PATCH 5/5] Align timesteps --- examples/levy-ssm/script.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/levy-ssm/script.jl b/examples/levy-ssm/script.jl index f7e9b62d..2b35858f 100644 --- a/examples/levy-ssm/script.jl +++ b/examples/levy-ssm/script.jl @@ -96,8 +96,8 @@ N = 200 ts = range(start, finish; length=N) seed = 4 rng = Random.MersenneTwister(seed) -Np = 10 -Ns = 10 +Np = 50 +Ns = 100 f(dt, θ) = exp(θ * dt) function Base.exp(dyn::LangevinDynamics, dt::Real) @@ -242,6 +242,7 @@ std_trajectory = dropdims(std(stack(marginal_states); dims=3); dims=3) ps = [] for d in 1:2 p = plot( + ts, mean_trajectory[:, d]; ribbon=2 * std_trajectory[:, d]', color=:darkorange, @@ -249,7 +250,7 @@ for d in 1:2 fillalpha=0.2, title="Marginal State Trajectories (X$d)", ) - plot!(p, X[:, d]; color=:dodgerblue, label="True Trajectory") + plot!(p, ts, X[:, d]; color=:dodgerblue, label="True Trajectory") push!(ps, p) end plot(ps...; layout=(2, 1), size=(600, 600))