Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jumps to particle state #96

Merged
merged 1 commit into from
Feb 23, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 94 additions & 16 deletions examples/levy-ssm/gamma_process.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ Ns = 10
f(dt, θ) = exp(θ * dt)
function Base.exp(dyn::LangevinDynamics, dt::Real)
let θ = dyn.θ
return [1.0 (f(dt, θ) - 1)/θ; 0 f(dt, θ)]
f_val = f(dt, θ)
return [1.0 (f_val - 1)/θ; 0 f_val]
end
end

Expand All @@ -125,7 +126,7 @@ for (i, t) in enumerate(ts)
dt = t - s
path = simulate(process, dt, s, t, ϵ)
μ, Σ = meancov(t, dyn, path, nvm)
X[i, :] = rand(MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ))
X[i, :] .= rand(MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ))
end

let H = dyn.H, σe = dyn.σe
Expand All @@ -141,36 +142,113 @@ Parameters = @NamedTuple begin
times::Vector{Float64} # Ugly, but avoids global de-ref
end

struct MixedState{T}
x::Vector{T}
path::GammaPath{T}
end

mutable struct LevyLangevin <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Vector{Float64}}
X::Vector{MixedState{Float64}}
θ::Parameters
LevyLangevin(θ::Parameters) = new(Vector{Array{2,Float64}}(), θ)
LevyLangevin(θ::Parameters) = new(Vector{MixedState{Float64}}(), θ)
end

struct InitialDistribution end
function Base.rand(rng::Random.AbstractRNG, ::InitialDistribution)
return MixedState(rand(MultivariateNormal([0, 0], I)), GammaPath(Float64[], Float64[]))
end

struct TransitionDistribution{T}
current_state::MixedState{T}
model::LevyLangevin
s::T
t::T
end

function Base.rand(rng::Random.AbstractRNG, td::TransitionDistribution)
let model = td.model, s = td.s, t = td.t, state = td.current_state
dt = t - s
path = simulate(model.θ.process, dt, s, t, ϵ)
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
Σ += 1e-6 * I
return MixedState(rand(MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ)), path)
end
end

# Required for ancestor sampling in PGAS
function Distributions.logpdf(td::TransitionDistribution, state::MixedState)
let model = td.model, s = td.s, t = td.t, state = td.current_state
dt = t - s
path = simulate(model.θ.process, dt, s, t, ϵ)
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
Σ += 1e-6 * I
return logpdf(MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ), state.x)
end
end

θ₀ = Parameters((dyn, process, nvm, ts))

AdvancedPS.initialization(model::LevyLangevin) = MultivariateNormal([0, 0], I)
function AdvancedPS.transition(model::LevyLangevin, state, step)
function AdvancedPS.initialization(::LevyLangevin)
return InitialDistribution()
end
function AdvancedPS.transition(model::LevyLangevin, state::MixedState, step)
times = model.θ.times
s = times[step - 1]
t = times[step]
dt = t - s
path = simulate(model.θ.process, dt, s, t, ϵ)
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
return MultivariateNormal(exp(dyn, dt) * state + μ, Σ)
return TransitionDistribution(state, model, s, t)
end

function AdvancedPS.observation(model::LevyLangevin, state, step)
return logpdf(Normal(transpose(H) * state, σe), Y[step])
function AdvancedPS.observation(model::LevyLangevin, state::MixedState, step)
return logpdf(Normal(transpose(H) * state.x, σe), Y[step])
end
AdvancedPS.isdone(::LevyLangevin, step) = step > length(ts)

model = LevyLangevin(θ₀)
pg = AdvancedPS.PG(Np, 1.0)
pg = AdvancedPS.PG(Np)
chains = sample(rng, model, pg, Ns; progress=true);

particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = transpose(hcat(mean(particles; dims=2)...))
marginal_states = map(s -> s.x, particles);
jump_times = map(s -> s.path.times, particles);
jump_intensities = map(s -> s.path.jumps, particles);

# Plot marginal state and jump intensities for one trajectory
p1 = plot(
ts,
[state[1] for state in marginal_states[:, end]];
color=:darkorange,
label="Marginal State (x1)",
)
plot!(
ts,
[state[2] for state in marginal_states[:, end]];
color=:dodgerblue,
label="Marginal State (x2)",
)

plot(X; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
p2 = scatter(
vcat([t for t in jump_times[:, end]]...),
vcat([j for j in jump_intensities[:, end]]...);
color=:darkorange,
label="Jumps",
)

plot(
p1, p2; plot_title="Marginal State and Jump Intensities", layout=(2, 1), size=(600, 600)
)

# Plot mean trajectory with standard deviation
mean_trajectory = transpose(hcat(mean(marginal_states; dims=2)...))
std_trajectory = dropdims(std(stack(marginal_states); dims=3); dims=3)

plot(
mean_trajectory;
ribbon=std_trajectory',
color=:darkorange,
label="Mean trajectory",
opacity=0.3,
title="Inference Quality",
)
plot!(
mean_trajectory; color=:dodgerblue, label="Original Trajectory", title="Path Degeneracy"
)
Loading