Skip to content

Commit

Permalink
Merge fb5c8ad into 1e5dfdd
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Mar 24, 2024
2 parents 1e5dfdd + fb5c8ad commit 5b7772f
Show file tree
Hide file tree
Showing 16 changed files with 131 additions and 89 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[weakdeps]
Expand All @@ -21,10 +22,10 @@ AdvancedPSLibtaskExt = "Libtask"
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Random = "1.6"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
Random = "1.6"
julia = "1.6"

[extras]
Expand Down
1 change: 1 addition & 0 deletions examples/gaussian-process/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
35 changes: 23 additions & 12 deletions examples/gaussian-process/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@ using AbstractGPs
using Plots
using Distributions
using Libtask
using SSMProblems

Parameters = @NamedTuple begin
a::Float64
q::Float64
kernel
end

mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters

GPSSM(params::Parameters) = new(Vector{Float64}(), params)
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
end

seed = 1
Expand All @@ -29,21 +32,20 @@ q = 0.5

params = Parameters((a, q, SqExponentialKernel()))

f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q)
h(model::GPSSM) = Normal(0, model.θ.q)
g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2)
f(θ::Parameters, x, t) = Normal.a * x, θ.q)
h(θ::Parameters) = Normal(0, θ.q)
g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)

rng = Random.MersenneTwister(seed)
ref_model = GPSSM(params)

x = zeros(T)
y = similar(x)
x[1] = rand(rng, h(ref_model))
x[1] = rand(rng, h(params))
for t in 1:T
if t < T
x[t + 1] = rand(rng, f(ref_model, x[t], t))
x[t + 1] = rand(rng, f(params, x[t], t))
end
y[t] = rand(rng, g(ref_model, x[t], t))
y[t] = rand(rng, g(params, x[t], t))
end

function gp_update(model::GPSSM, state, step)
Expand All @@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step)
return Normal(μ[1], σ[1])
end

AdvancedPS.initialization(::GPSSM) = h(model)
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
return rand(rng, gp_update(model, state, step))
end

function SSMProblems.emission_logdensity(model::GPSSM, state, step)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
return logpdf(gp_update(model, prev_state, step), current_state)
end

AdvancedPS.isdone(::GPSSM, step) = step > T

model = GPSSM(params)
model = GPSSM(y, params)
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pg, Nₛ)

Expand Down
1 change: 1 addition & 0 deletions examples/gaussian-ssm/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
42 changes: 27 additions & 15 deletions examples/gaussian-ssm/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using AdvancedPS
using Random
using Distributions
using Plots
using SSMProblems

# We consider the following linear state-space model with Gaussian innovations. The latent state is a simple gaussian random walk
# and the observation is linear in the latent states, namely:
Expand Down Expand Up @@ -33,32 +34,45 @@ Parameters = @NamedTuple begin
r::Float64
end

mutable struct LinearSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct LinearSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters
LinearSSM::Parameters) = new(Vector{Float64}(), θ)
LinearSSM(y::Vector, θ::Parameters) = new(Vector{Float64}(), y, θ)
end

# and the densities defined above.
f(m::LinearSSM, state, t) = Normal(m.θ.a * state, m.θ.q) # Transition density
g(m::LinearSSM, state, t) = Normal(state, m.θ.r) # Observation density
f₀(m::LinearSSM) = Normal(0, m.θ.q^2 / (1 - m.θ.a^2)) # Initial state density
f(θ::Parameters, state, t) = Normal.a * state, θ.q) # Transition density
g(θ::Parameters, state, t) = Normal(state, θ.r) # Observation density
f₀(θ::Parameters) = Normal(0, θ.q^2 / (1 - θ.a^2)) # Initial state density
#md nothing #hide

# We also need to specify the dynamics of the system through the transition equations:
# - `AdvancedPS.initialization`: the initial state density
# - `AdvancedPS.transition`: the state transition density
# - `AdvancedPS.observation`: the observation score given the observed data
# - `AdvancedPS.isdone`: signals the end of the execution for the model
AdvancedPS.initialization(model::LinearSSM) = f₀(model)
AdvancedPS.transition(model::LinearSSM, state, step) = f(model, state, step)
function AdvancedPS.observation(model::LinearSSM, state, step)
return logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::LinearSSM) = rand(rng, f₀(model.θ))
function SSMProblems.transition!!(
rng::AbstractRNG, model::LinearSSM, state::Float64, step::Int
)
return rand(rng, f(model.θ, state, step))
end

function SSMProblems.emission_logdensity(modeL::LinearSSM, state::Float64, step::Int)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(
model::LinearSSM, prev_state, current_state, step
)
return logpdf(f(model.θ, prev_state, step), current_state)
end

# We need to think seriously about how the data is handled
AdvancedPS.isdone(::LinearSSM, step) = step > Tₘ

# Everything is now ready to simulate some data.

a = 0.9 # Scale
q = 0.32 # State variance
r = 1 # Observation variance
Expand All @@ -72,14 +86,12 @@ rng = Random.MersenneTwister(seed)

x = zeros(Tₘ)
y = zeros(Tₘ)

reference = LinearSSM(θ₀)
x[1] = rand(rng, f₀(reference))
x[1] = rand(rng, f₀(θ₀))
for t in 1:Tₘ
if t < Tₘ
x[t + 1] = rand(rng, f(reference, x[t], t))
x[t + 1] = rand(rng, f(θ₀, x[t], t))
end
y[t] = rand(rng, g(reference, x[t], t))
y[t] = rand(rng, g(θ₀, x[t], t))
end

# Here are the latent and obseravation timeseries
Expand All @@ -88,7 +100,7 @@ plot!(y; seriestype=:scatter, label="y", xlabel="t", mc=:red, ms=2, ma=0.5)

# `AdvancedPS` subscribes to the `AbstractMCMC` API. To sample we just need to define a Particle Gibbs kernel
# and a model interface.
model = LinearSSM(θ₀)
model = LinearSSM(y, θ₀)
pgas = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pgas, Nₛ; progress=false);
#md nothing #hide
Expand Down
1 change: 1 addition & 0 deletions examples/particle-gibbs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
32 changes: 26 additions & 6 deletions examples/particle-gibbs/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Distributions
using Plots
using AbstractMCMC
using Random123
using SSMProblems

"""
plot_update_rate(update_rate, N)
Expand Down Expand Up @@ -90,22 +91,41 @@ plot(x; label="x", xlabel="t")
plot(y; label="y", xlabel="t")

# Each model takes an `AbstractRNG` as input and generates the logpdf of the current transition:
mutable struct NonLinearTimeSeries <: AdvancedPS.AbstractStateSpaceModel
mutable struct NonLinearTimeSeries <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters
NonLinearTimeSeries::Parameters) = new(Float64[], θ)
NonLinearTimeSeries(y::Vector{Float64}, θ::Parameters) = new(Float64[], y, θ)
end

# The dynamics of the model is defined through the `AbstractStateSpaceModel` interface:
AdvancedPS.initialization(model::NonLinearTimeSeries) = f₀(model.θ)
AdvancedPS.transition(model::NonLinearTimeSeries, state, step) = f(model.θ, state, step)
function AdvancedPS.observation(model::NonLinearTimeSeries, state, step)
return logpdf(g(model.θ, state, step), y[step])
function SSMProblems.transition!!(rng::AbstractRNG, model::NonLinearTimeSeries)
return rand(rng, f₀(model.θ))
end
function SSMProblems.transition!!(
rng::AbstractRNG, model::NonLinearTimeSeries, state::Float64, step::Int
)
return rand(rng, f(model.θ, state, step))
end

function SSMProblems.emission_logdensity(
modeL::NonLinearTimeSeries, state::Float64, step::Int
)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(
model::NonLinearTimeSeries, prev_state, current_state, step
)
return logpdf(f(model.θ, prev_state, step), current_state)
end

# We need to tell AdvancedPS when to stop the execution of the model
# TODO
AdvancedPS.isdone(::NonLinearTimeSeries, step) = step > Tₘ

# Here we use the particle gibbs kernel without adaptive resampling.
model = NonLinearTimeSeries(θ₀)
model = NonLinearTimeSeries(y, θ₀)
pg = AdvancedPS.PG(Nₚ, 1.0)
chains = sample(rng, model, pg, Nₛ; progress=false);
#md nothing #hide
Expand Down
1 change: 1 addition & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Distributions: Distributions
using Random: Random
using StatsFuns: StatsFuns
using Random123: Random123
using SSMProblems: SSMProblems

abstract type AbstractParticleModel <: AbstractMCMC.AbstractModel end

Expand Down
2 changes: 1 addition & 1 deletion src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ function reweight!(pc::ParticleContainer, ref::Union{Particle,Nothing}=nothing)
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
# ``θᵢ`` are variables of other samplers.
isref = p === ref
score = advance!(p, isref)
score = advance!(p, isref) # SSMProblems.transition!!

if score === nothing
numdone += 1
Expand Down
9 changes: 8 additions & 1 deletion src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,20 @@ mutable struct Trace{F,R}
end

const Particle = Trace
const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R}
const SSMTrace{R} = Trace{<:SSMProblems.AbstractStateSpaceModel,R}
const GenericTrace{R} = Trace{<:AbstractGenericModel,R}

reset_logprob!(::AdvancedPS.Particle) = nothing
reset_model(f) = deepcopy(f)
delete_retained!(f) = nothing

"""
isdone(model::SSMProblems.AbstractStateSpaceModel, step)
Returns `true` if we reached the end of the model execution
"""
function isdone end

"""
copy(trace::Trace)
Expand Down
48 changes: 9 additions & 39 deletions src/pgas.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,3 @@
"""
initialization(model::AbstractStateSpaceModel)
Define the distribution of the initial state of the State Space Model
"""
function initialization end

"""
transition(model::AbstractStateSpaceModel, state, step)
Define the transition density of the State Space Model
Must return `nothing` if it consumed all the data
"""
function transition end

"""
observation(model::AbstractStateSpaceModel, state, step)
Return the log-likelihood of the observed measurement conditional on the current state of the model.
Must return `nothing` if it consumed all the data
"""
function observation end

"""
isdone(model::AbstractStateSpaceModel, step)
Return `true` if model reached final state else `false`
"""
function isdone end

"""
previous_state(trace::SSMTrace)
Expand All @@ -54,13 +24,11 @@ current_step(trace::SSMTrace) = trace.rng.count
Get the log weight of the transition from previous state of `model` to `x`
"""
function transition_logweight(particle::SSMTrace, x)
score = Distributions.logpdf(
transition(
particle.model,
particle.model.X[current_step(particle) - 2],
current_step(particle) - 2,
),
score = SSMProblems.transition_logdensity(
particle.model,
particle.model.X[current_step(particle) - 2],
x,
current_step(particle) - 1,
)
return score
end
Expand Down Expand Up @@ -93,16 +61,18 @@ function advance!(particle::SSMTrace, isref::Bool=false)

if !isref
if running_step == 1
new_state = rand(particle.rng, initialization(model)) # Generate initial state, maybe fallback to 0 if initialization is not defined
new_state = SSMProblems.transition!!(particle.rng, model)
else
current_state = model.X[running_step - 1]
new_state = rand(particle.rng, transition(model, current_state, running_step))
new_state = SSMProblems.transition!!(
particle.rng, model, current_state, running_step
)
end
else
new_state = model.X[running_step] # We need the current state from the reference particle
end

score = observation(model, new_state, running_step)
score = SSMProblems.emission_logdensity(model, new_state, running_step)

# accept transition
!isref && push!(model.X, new_state)
Expand Down
Loading

0 comments on commit 5b7772f

Please sign in to comment.