In [None]:
parent_dir = joinpath(@__DIR__, "..")
include(joinpath(parent_dir, "helper_split.jl"))
include(joinpath(parent_dir, "algorithms.jl"))
include(joinpath(parent_dir, "helper_tt.jl"))

using Distributions
using ForwardDiff
using LinearAlgebra
using Plots

In [None]:
# Target model: equal-weight mixture of `nr_mixtures` Gaussian components in `dim` dimensions.
dim = 1
nr_mixtures = 2
λ = inv(nr_mixtures)                 # common weight of every component

# Component means: ±10 in every coordinate, then halved, i.e. ±5.
μ = [10 * ones(dim), -10 * ones(dim)] / 2

correlation = 0                      # NOTE(review): not referenced in the visible code
Σ = Matrix{Float64}(I, dim, dim)     # identity covariance shared by all components
Σ_inv = inv(Σ)

# One callable density per mean (each is later evaluated as `mixtures[i](x)`).
mixtures = [gaussian_pdf(center, Σ, Σ_inv) for center in μ]

"""
    target(x)

Evaluate the mixture density `∑ᵢ λ * mixtures[i](x)` at the point `x`.

Reads the globals `λ` (common component weight) and `mixtures` (callable component
densities). Every entry of `mixtures` contributes, so the result stays consistent with
however many components were actually built.
"""
function target(x::AbstractVector)
    # `sum` avoids seeding an accumulator with the Int literal `0`. That seed made the
    # original loop type-unstable: Int → Float64, or → ForwardDiff.Dual under `∇U`.
    return sum(mixture -> λ * mixture(x), mixtures)
end
# Potential (negative log-density) of the mixture, and its gradient via ForwardDiff.
U(x)  = -log(target(x))
∇U(x) = ForwardDiff.gradient(U, x)

# Smallest and largest component mean. `only` makes the dim = 1 assumption explicit:
# it throws for multi-dimensional means instead of silently indexing.
m = minimum(only, μ)
M = maximum(only, μ)

# Upper bound on the Hessian of U.
# NOTE(review): the bound 1 + (M - m)^2 / 4 presumes unit-variance components
# (Σ = I, as set above). Re-derive it if Σ changes.
Hessian_bound = 1 + 0.25 * (M - m)^2

α = 0.7
# The time-changed potential is (1 - α) * U (see ∇U_s below), so its Hessian bound
# scales by the same factor. It is expressed as a dim × dim matrix for ZigZag.
Hessian_bound_alpha = ((1 - α) * Hessian_bound) * Matrix(I, dim, dim)

# Speed function of the time change.
s(x) = exp(α * U(x))
# -log(s(x) * exp(-U(x))) = (1 - α) * U(x), hence the scaled gradient.
∇U_s(x) = (1 - α) * ∇U(x)
# Unnormalised time-changed target; equals exp(-(1 - α) * U(x)).
unnorm_target_s(x) = s(x) * exp(-U(x))

In [None]:
T = 2 * 10^2                   # time horizon passed to ZigZag
# Copy the mean: passing μ[1] itself would alias it. Any in-place update of the
# start point by the sampler would then silently move the mixture component.
x_init = copy(μ[1])
v_init = rand((-1, 1), dim)    # random ±1 initial velocity in each coordinate
skele = ZigZag(∇U_s, Hessian_bound_alpha, T, x_init, v_init)

# Event times and positions of the returned skeleton.
pos   = [skele[i].position for i in eachindex(skele)]
times = [skele[i].time for i in eachindex(skele)]

# Grey reference lines at the two means, then the trajectory in red.
p_std = hline(μ[2], lc=:grey, label="", lw=0.5)
hline!(μ[1], lc=:grey, label="", lw=0.5)
# `reduce(vcat, …)` concatenates without splatting a long collection into varargs.
plot!(times, reduce(vcat, pos), label="", lc="red",
      lw=1.5, grid=:none, xlabel="time", ylabel="position",
      xlims=[0, 1.02 * times[end]], ylims=[-22, 22])
display(p_std)

In [None]:
delta = 1e-1   # discretisation step for the skeleton
discretised = discretise_from_skeleton(skele, delta)

# Approximate the time-changed chain with speed function `s`.
chain_tt  = approximate_timechanged_skele(discretised, s, delta)
times_tt  = [chain_tt[n].time for n in eachindex(chain_tt)]
positions = [chain_tt[n].position for n in eachindex(chain_tt)]
speeds    = [chain_tt[n].speed for n in eachindex(chain_tt)]

# Grey reference lines at the two means, then the time-changed trajectory in red.
p_tt = hline(μ[2], lc=:grey, label="", lw=0.5)
hline!(μ[1], lc=:grey, label="", lw=0.5)
# The chain has roughly T / delta points. `reduce(vcat, …)` avoids splatting them
# all into a single varargs call.
plot!(times_tt, reduce(vcat, positions), label="", lc="red",
      lw=1.5, grid=:none, xlabel="time", ylabel="position",
      xlims=[0, 1.02 * times_tt[end]], ylims=[-22, 22])
display(p_tt)