Skip to content

Commit

Permalink
Format, GP-SSM
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Mar 24, 2024
1 parent 2249933 commit fb5c8ad
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 17 deletions.
1 change: 1 addition & 0 deletions examples/gaussian-process/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
35 changes: 23 additions & 12 deletions examples/gaussian-process/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,21 @@ using AbstractGPs
using Plots
using Distributions
using Libtask
using SSMProblems

Parameters = @NamedTuple begin
a::Float64
q::Float64
kernel
end

mutable struct GPSSM <: AdvancedPS.AbstractStateSpaceModel
mutable struct GPSSM <: SSMProblems.AbstractStateSpaceModel
X::Vector{Float64}
observations::Vector{Float64}
θ::Parameters

GPSSM(params::Parameters) = new(Vector{Float64}(), params)
GPSSM(y::Vector{Float64}, params::Parameters) = new(Vector{Float64}(), y, params)
end

seed = 1
Expand All @@ -29,21 +32,20 @@ q = 0.5

params = Parameters((a, q, SqExponentialKernel()))

f(model::GPSSM, x, t) = Normal(model.θ.a * x, model.θ.q)
h(model::GPSSM) = Normal(0, model.θ.q)
g(model::GPSSM, x, t) = Normal(0, exp(0.5 * x)^2)
f(θ::Parameters, x, t) = Normal.a * x, θ.q)
h(θ::Parameters) = Normal(0, θ.q)
g(θ::Parameters, x, t) = Normal(0, exp(0.5 * x)^2)

rng = Random.MersenneTwister(seed)
ref_model = GPSSM(params)

x = zeros(T)
y = similar(x)
x[1] = rand(rng, h(ref_model))
x[1] = rand(rng, h(params))
for t in 1:T
if t < T
x[t + 1] = rand(rng, f(ref_model, x[t], t))
x[t + 1] = rand(rng, f(params, x[t], t))
end
y[t] = rand(rng, g(ref_model, x[t], t))
y[t] = rand(rng, g(params, x[t], t))
end

function gp_update(model::GPSSM, state, step)
Expand All @@ -54,12 +56,21 @@ function gp_update(model::GPSSM, state, step)
return Normal(μ[1], σ[1])
end

AdvancedPS.initialization(::GPSSM) = h(model)
AdvancedPS.transition(model::GPSSM, state, step) = gp_update(model, state, step)
AdvancedPS.observation(model::GPSSM, state, step) = logpdf(g(model, state, step), y[step])
SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM) = rand(rng, h(model.θ))
function SSMProblems.transition!!(rng::AbstractRNG, model::GPSSM, state, step)
return rand(rng, gp_update(model, state, step))
end

function SSMProblems.emission_logdensity(model::GPSSM, state, step)
return logpdf(g(model.θ, state, step), model.observations[step])
end
function SSMProblems.transition_logdensity(model::GPSSM, prev_state, current_state, step)
return logpdf(gp_update(model, prev_state, step), current_state)
end

AdvancedPS.isdone(::GPSSM, step) = step > T

model = GPSSM(params)
model = GPSSM(y, params)
pg = AdvancedPS.PGAS(Nₚ)
chains = sample(rng, model, pg, Nₛ)

Expand Down
20 changes: 15 additions & 5 deletions test/pgas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,20 @@
BaseModel(params::Params) = new(Vector{Float64}(), params)
end

SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel) = rand(rng, Normal(0, model.θ.q))
SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel, state, step) = rand(rng, Normal(model.θ.a * state, model.θ.q))
SSMProblems.emission_logdensity(model::BaseModel, state, step) = logpdf(Distributions.Normal(state, model.θ.r), 0)
SSMProblems.transition_logdensity(model::BaseModel, prev_state, current_state, step) = logpdf(Normal(model.θ.a * prev_state, model.θ.q), current_state)
function SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel)
return rand(rng, Normal(0, model.θ.q))
end
function SSMProblems.transition!!(rng::AbstractRNG, model::BaseModel, state, step)
return rand(rng, Normal(model.θ.a * state, model.θ.q))
end
function SSMProblems.emission_logdensity(model::BaseModel, state, step)
return logpdf(Distributions.Normal(state, model.θ.r), 0)
end
function SSMProblems.transition_logdensity(
model::BaseModel, prev_state, current_state, step
)
return logpdf(Normal(model.θ.a * prev_state, model.θ.q), current_state)
end

AdvancedPS.isdone(::BaseModel, step) = step > 3

Expand Down Expand Up @@ -80,7 +90,7 @@
seed = 10
rng = Random.MersenneTwister(seed)

for sampler in [AdvancedPS.PGAS(10)]
for sampler in [AdvancedPS.PGAS(10), AdvancedPS.PG(10)]
Random.seed!(rng, seed)
chain1 = sample(rng, model, sampler, 10)
vals1 = hcat([chain.trajectory.model.X for chain in chain1]...)
Expand Down

0 comments on commit fb5c8ad

Please sign in to comment.