Skip to content

Commit

Permalink
Add jumps to particle state
Browse files Browse the repository at this point in the history
  • Loading branch information
THargreaves committed Jan 15, 2024
1 parent 2902fa6 commit 76ff740
Showing 1 changed file with 94 additions and 16 deletions.
110 changes: 94 additions & 16 deletions examples/levy-ssm/gamma_process.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ Ns = 10
f(dt, θ) = exp* dt)
function Base.exp(dyn::LangevinDynamics, dt::Real)
let θ = dyn.θ
return [1.0 (f(dt, θ) - 1)/θ; 0 f(dt, θ)]
f_val = f(dt, θ)
return [1.0 (f_val - 1)/θ; 0 f_val]
end
end

Expand All @@ -125,7 +126,7 @@ for (i, t) in enumerate(ts)
dt = t - s
path = simulate(process, dt, s, t, ϵ)
μ, Σ = meancov(t, dyn, path, nvm)
X[i, :] = rand(MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ))
X[i, :] .= rand(MultivariateNormal(exp(dyn, dt) * X[i - 1, :] + μ, Σ))
end

let H = dyn.H, σe = dyn.σe
Expand All @@ -141,36 +142,113 @@ Parameters = @NamedTuple begin
times::Vector{Float64} # Ugly, but avoids global de-ref
end

struct MixedState{T}
x::Vector{T}
path::GammaPath{T}
end

mutable struct LevyLangevin <: AdvancedPS.AbstractStateSpaceModel
X::Vector{Vector{Float64}}
X::Vector{MixedState{Float64}}
θ::Parameters
LevyLangevin::Parameters) = new(Vector{Array{2,Float64}}(), θ)
LevyLangevin::Parameters) = new(Vector{MixedState{Float64}}(), θ)
end

struct InitialDistribution end
function Base.rand(rng::Random.AbstractRNG, ::InitialDistribution)
return MixedState(rand(MultivariateNormal([0, 0], I)), GammaPath(Float64[], Float64[]))
end

struct TransitionDistribution{T}
current_state::MixedState{T}
model::LevyLangevin
s::T
t::T
end

function Base.rand(rng::Random.AbstractRNG, td::TransitionDistribution)
let model = td.model, s = td.s, t = td.t, state = td.current_state
dt = t - s
path = simulate(model.θ.process, dt, s, t, ϵ)
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
Σ += 1e-6 * I
return MixedState(rand(MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ)), path)
end
end

# Required for ancestor sampling in PGAS
function Distributions.logpdf(td::TransitionDistribution, state::MixedState)
let model = td.model, s = td.s, t = td.t, state = td.current_state
dt = t - s
path = simulate(model.θ.process, dt, s, t, ϵ)
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
Σ += 1e-6 * I
return logpdf(MultivariateNormal(exp(dyn, dt) * state.x + μ, Σ), state.x)
end
end

θ₀ = Parameters((dyn, process, nvm, ts))

AdvancedPS.initialization(model::LevyLangevin) = MultivariateNormal([0, 0], I)
function AdvancedPS.transition(model::LevyLangevin, state, step)
function AdvancedPS.initialization(::LevyLangevin)
return InitialDistribution()
end
function AdvancedPS.transition(model::LevyLangevin, state::MixedState, step)
times = model.θ.times
s = times[step - 1]
t = times[step]
dt = t - s
path = simulate(model.θ.process, dt, s, t, ϵ)
μ, Σ = meancov(t, model.θ.dyn, path, model.θ.nvm)
return MultivariateNormal(exp(dyn, dt) * state + μ, Σ)
return TransitionDistribution(state, model, s, t)
end

function AdvancedPS.observation(model::LevyLangevin, state, step)
return logpdf(Normal(transpose(H) * state, σe), Y[step])
function AdvancedPS.observation(model::LevyLangevin, state::MixedState, step)
return logpdf(Normal(transpose(H) * state.x, σe), Y[step])
end
AdvancedPS.isdone(::LevyLangevin, step) = step > length(ts)

model = LevyLangevin(θ₀)
pg = AdvancedPS.PG(Np, 1.0)
pg = AdvancedPS.PG(Np)
chains = sample(rng, model, pg, Ns; progress=true);

particles = hcat([chain.trajectory.model.X for chain in chains]...) # Concat all sampled states
mean_trajectory = transpose(hcat(mean(particles; dims=2)...))
marginal_states = map(s -> s.x, particles);
jump_times = map(s -> s.path.times, particles);
jump_intensities = map(s -> s.path.jumps, particles);

# Plot marginal state and jump intensities for one trajectory
p1 = plot(
ts,
[state[1] for state in marginal_states[:, end]];
color=:darkorange,
label="Marginal State (x1)",
)
plot!(
ts,
[state[2] for state in marginal_states[:, end]];
color=:dodgerblue,
label="Marginal State (x2)",
)

plot(X; color=:darkorange, label="Original Trajectory")
plot!(mean_trajectory; color=:dodgerblue, label="Mean trajectory", opacity=0.9)
p2 = scatter(
vcat([t for t in jump_times[:, end]]...),
vcat([j for j in jump_intensities[:, end]]...);
color=:darkorange,
label="Jumps",
)

plot(
p1, p2; plot_title="Marginal State and Jump Intensities", layout=(2, 1), size=(600, 600)
)

# Plot mean trajectory with standard deviation
mean_trajectory = transpose(hcat(mean(marginal_states; dims=2)...))
std_trajectory = dropdims(std(stack(marginal_states); dims=3); dims=3)

plot(
mean_trajectory;
ribbon=std_trajectory',
color=:darkorange,
label="Mean trajectory",
opacity=0.3,
title="Inference Quality",
)
plot!(
mean_trajectory; color=:dodgerblue, label="Original Trajectory", title="Path Degeneracy"
)

0 comments on commit 76ff740

Please sign in to comment.