# Plots
Figures for the linear uncertainty solver paper.

In [None]:
using Plots

using Measurements
using Distributions
using LaTeXStrings

In [None]:
# Optional: uncomment and run this cell to switch to the PGFPlotsX backend for
# publication-quality plots. These take noticeably longer to build and need more RAM,
# and they require a LaTeX installation with the relevant packages.
# pgfplotsx()

## Linear Prototype Model

In [None]:
# Initial variables for the linear prototype model.
# μₓ: mean of the initial state x₀.
μₓ = 0.5

# Parameters of the linear (GP-approximation) vector field f(x) = -a x + β:
# `a` is the decay rate and `b` the spread of the random offset β.
# Note: the sketch cells below pass `b` to Normal as a std. deviation,
# while the sampling cells use it as Var(β) (they draw randn() * sqrt(b)).
a, b = 2.0, 0.1

### Linear model

In [None]:
# Sketch of the linear prototype model dx/dt = -a x + b for a handful of offsets b.
# x0/Σ0: centre and spread of the initial condition (Σ0 is used as a std. deviation).
x0 = 0.3
Σ0 = 0.05
# offsets to draw, symmetric around 0, and three initial conditions within ±2Σ0
bs = range(-2.5*b, 2.5*b, length = 5)
x0s = range(x0-2*Σ0, x0+2*Σ0, length = 3)

# vector field f(x) = -a x + b and its closed-form solution
# x(t) = e^{-a t} x0 + (b/a)(1 - e^{-a t})
line(x, b) = -a * x + b
sol(x0, b, t) = exp(-a*t)*x0 + b/a * (1-exp(-a*t))

xs = range(-0.2, 0.2, length = 100)
ts = range(0., 2., length = 100)

# 8-step "blues" colormap with the first two entries skipped (presumably the palest);
# one colour per offset, cls[end] is reused for the density curve
cls = Plots.Colors.colormap("blues", 8)[3:end]

# p1: vector field panel, p2: trajectory panel
p1 = plot(; xlabel = "x", ylabel = "f(x)", legend = :topright)
p2 = plot(; xlabel = "time t", ylabel = "x", legend = :topright)

# `b` and `x0` below are loop variables; they shadow the globals only inside the loops
for (i, b) in enumerate(bs)
    plot!(p1, xs, line.(xs, b); color = cls[i], label = "b=$b")
    for x0 in x0s
        plot!(p2, ts, sol.(x0, b, ts); color = cls[i], label = "")
    end
end

# line width for the density curves drawn on both panels
dist_lhwd = 2.2
# distribution of lines in model plot: the Gaussian density of f(anchor) = -a*anchor + β
# (mean -a*anchor, std b), drawn sideways at x = anchor and scaled by 0.01 so it
# fits into the x-range of the panel
anchor = -0.05
b_dist = Normal(-anchor*a,b)
dxs = range(-anchor*a-3.5*b, -anchor*a+3.5*b, length = 40)
plot!(p1, anchor .+ pdf.(b_dist, dxs)*0.01, dxs; 
    label = L"\mathcal{N}(-a x, \beta)" , color = cls[end], linewidth = dist_lhwd)
# vertical black baseline at x = anchor for the sideways density
plot!(p1, anchor .* ones(2), [-anchor*a-4*b, -anchor*a+4*b];
    color = :black, label = "")

# trajectory distributions (densities drawn sideways on the trajectory panel,
# scaled by 0.016 to fit the time axis)
## initial: x(0) ~ Normal(x0, Σ0), drawn at t = 0
x0_dist = Normal(x0, Σ0)
x0ds = range(x0-3*Σ0, x0+3*Σ0, length = 70)
plot!(p2, pdf.(x0_dist, x0ds)*0.016, x0ds; 
    linewidth = dist_lhwd, label = L"\mathcal{N}(\mu_0, \Sigma_0)")
plot!(p2, [0., 0.], [x0-3*Σ0, x0+3*Σ0]; color = :black, label = "")
## end: drawn at t = 2 using the t → ∞ limit b/a of `sol` (the e^{-at} x0 term is dropped),
## so its std. deviation is b/a for offsets with std. deviation b.
## NOTE(review): Σ2 is passed to Normal as a std. deviation while the legend writes the
## variance as β/a²; this only matches if β denotes the offset variance (= b²) — confirm.
Σ2 = b/(a)
x2_dist = Normal(0, Σ2)
x2ds = range(-3*Σ2, 3*Σ2, length = 70)
plot!(p2, 2 .+ pdf.(x2_dist, x2ds)*0.016, x2ds; 
    linewidth = dist_lhwd, label = L"\mathcal{N}(0, \beta/a^2)")
plot!(p2, [2., 2.], [-3*Σ2, 3*Σ2]; color = :black, label = "")


# stack the vector-field panel above the trajectory panel
lin_plot = plot(p1, p2; layout = (2, 1), size = (450, 350), bottom_margin = -6Plots.Measures.mm)

In [None]:
# Re-render the linear-model figure at a fixed size for export.
ptmp = plot(lin_plot, size = (450, 350))
# uncomment to save plot
# savefig(ptmp, "proto_lin_model.pdf")

### Moment-Matching extended to Euler (Bad Euler)

In [None]:
# Parameters for the moment-matching ("bad Euler") experiments.
# β: offset stored in the base ODEProblem; the ensemble prob_func overwrites it per trajectory.
β = -.5
# base step size
h = 0.01;
tspan = [0., 2.]
# mean and variance of the initial state (the variance enters as sqrt(Σₓ) when sampling)
μₓ = 0.5
Σₓ = 0.002
# deterministic initial state for the ODEProblem; initial noise is added per trajectory
x0 = [μₓ] #± Σₓ

# generate maximally distinguishable colors (white and black are seeds and are dropped
# from the result via dropseed=true)
cols = distinguishable_colors(6, [RGB(1,1,1), RGB(0,0,0)], dropseed=true)

In [None]:
using DifferentialEquations
using DifferentialEquations.EnsembleAnalysis

# Right-hand side of the linear ODE dx/dt = -a x + β with β = p[1]
# (broadcast so that the state x may be a vector).
f(x, p, t) = -a.*x .+ p[1]
# Base problem; its u0 and p are replaced per trajectory by `prob_func`.
prob = ODEProblem(f, x0, tspan, [β])

# Per-trajectory randomisation for the Monte-Carlo ensemble:
#   offset        β  ~ N(0, b)    (b is used as a variance here: randn() * sqrt(b))
#   initial state x₀ ~ N(μₓ, Σₓ)
# The problem's p and u0 are overwritten in place and the problem is returned.
# Keep the order of the two randn() calls: it fixes the random stream.
function prob_func(prob, i, repeat)
    @. prob.p = randn() * sqrt(b)
    @. prob.u0 = μₓ + randn()*sqrt(Σₓ)
    prob
end

"""
    ensemble_meanstd(prob, prob_func, ntraj, tpoints)

Solve `ntraj` Monte-Carlo realisations of `prob`, each randomised by `prob_func`.
Return `(means, stds)`: the sample mean and the sample standard deviation of the
(one-dimensional) state at every time in `tpoints`.
"""
function ensemble_meanstd(prob, prob_func, ntraj, tpoints)
    ens_sol = solve(EnsembleProblem(prob; prob_func); trajectories = ntraj)
    meanvars = timepoint_meanvar.(Ref(ens_sol), tpoints)
    means = reduce(vcat, getindex.(meanvars, 1))
    stds = sqrt.(reduce(vcat, getindex.(meanvars, 2)))
    return means, stds
end

# Common evaluation grid for all sampling-based reference curves.
eval_tpoints = range(tspan[1], tspan[2], length = 80)

# High-accuracy sampling reference (500k trajectories).
sample_mean, sample_std = ensemble_meanstd(prob, prob_func, 500_000, eval_tpoints)

# Cheaper 50k-trajectory estimate, used to visualise sampling noise; its mean is not needed.
_, sample_std_low = ensemble_meanstd(prob, prob_func, 50_000, eval_tpoints);

In [None]:
# Variance update of one moment-matched ("bad") Euler step for dx/dt = -a x + β:
#   Var[(1 - a h) x + h β] ≈ (1 - a h)² Σ + h² b   (expanded term by term below).
# It treats x and β as uncorrelated at every step, i.e. it drops the
# 2 h (1 - a h) Cov(x, β) term that the corrected Euler variance keeps.
v(Σ, a, h) = Σ .+ h^2*b .+ h^2*a^2*Σ .- 2*h*a*Σ
# Same update using the global decay rate `a`.
vs(Σ, h) = v(Σ, a, h)

"""
    badeulerstep(x, h)

Advance the uncertain state `x` (a `Measurement`) by one moment-matched Euler step
of size `h`. The mean follows x ← (1 - a h) x, and the variance `x.err^2` is updated
with `vs`, which ignores any correlation between the state and the offset β.
A fresh `Measurement` is returned.
"""
function badeulerstep(x, h)
    μnext = x.val * (1 - a*h)
    σnext = sqrt(vs(x.err^2, h))
    return μnext ± σnext
end

"""
    badeuler(x0, tspan, h)

Integrate the moment-matched ("bad") Euler scheme from `x0` (a `Measurement`,
e.g. `μ₀ ± sqrt(Σ₀)`) over `tspan = [t_start, t_end]` with step size `h`.

Return `(th, xe)`: the time grid and the uncertain state at every grid point.
The scheme takes `nsteps = ceil((t_end - t_start) / h)` steps, so the last grid
point may lie slightly beyond `t_end` when `h` does not divide the interval.
"""
function badeuler(x0, tspan, h)
    nsteps = ceil(Int, (tspan[2] - tspan[1]) / h)

    xe = zeros(Measurement{Float64}, nsteps + 1)
    xe[1] = x0

    for i in 1:nsteps
        xe[i+1] = badeulerstep(xe[i], h)
    end

    # Build the grid from its start and length. This guarantees exactly nsteps + 1
    # points (matching `xe`) and starts at tspan[1] instead of assuming 0.
    th = range(tspan[1]; step = h, length = nsteps + 1)

    return th, xe
end

#### Step-size dependence

In [None]:
# plot for the paper: bad-Euler mean error (left) and variance (right) for step sizes
# 2h, h and h/2, against the 500k-trajectory sampling variance.
x0 = μₓ ± sqrt(Σₓ)

lbls = ["2h", "h", "h/2"]
p1 = plot(; xlabel = "time t", ylabel = "mean error")
p2 = plot(; legend = :right, xlabel = "time t", ylabel = "variance", left_margin = -5Plots.Measures.mm)
# analytic mean e^{-a t} x0 (the offset has mean zero)
an_mean_sol(x0, t) = exp(-a*t) * x0 

lstyles = [:solid, :dash, :dot]
# zero reference line for the mean error, sampling variance as reference for the variance
plot!(p1, eval_tpoints[[1, end]], [0., 0.]; color = :black, label = "")
plot!(p2, eval_tpoints, sample_std.^2; color = :black, label = "sampling")
# loop variable `h` shadows the global step size only inside the loop
for (i, h) in enumerate([h*2, h, h/2])
    be_sol = badeuler(x0, tspan, h)
    # mean error plot
    err = abs.(getfield.(be_sol[2], :val) .- an_mean_sol.(μₓ, be_sol[1]))
    plot!(p1, be_sol[1], err; label = "", color = cols[2], linestyle = lstyles[i])
    # var plot (Measurement `err` is a std. deviation, hence the square)
    plot!(p2, be_sol[1], getfield.(be_sol[2], :err).^2;
        label = lbls[i], color = cols[2], linestyle = lstyles[i])
end
be_plot = plot(p1, p2)

In [None]:
# for the poster
using Interpolations

# Poster variant of the step-size study: same runs as the paper cell, but it shows
# std. deviations instead of variances and uses thicker lines. It overwrites p1, p2
# and be_plot.
x0 = μₓ ± sqrt(Σₓ)

lbls = ["2h", "h", "h/2"]
p1 = plot(; xlabel = "time t", ylabel = "error of the mean")
p2 = plot(; legend = :right, xlabel = "time t", ylabel = "std. deviation", left_margin = -5Plots.Measures.mm)
an_mean_sol(x0, t) = exp(-a*t) * x0 

lwdh = 2.0
lstyles = [:solid, :dash, :dot]
plot!(p1, eval_tpoints[[1, end]], [0., 0.]; color = :black, label = "", linewidth = lwdh)
# `.^1` is a no-op: this plots the sampling std. deviation itself
plot!(p2, eval_tpoints, sample_std.^1; color = :black, label = "sampling")
# plot!(p2, eval_tpoints[[1, end]], [0., 0.]; color = :black, label = "sampling", linewidth = lwdh)

# NOTE(review): `sitp` is only used by the commented-out error line in the loop below;
# it is currently unused.
sitp = linear_interpolation(eval_tpoints, sample_std)

for (i, h) in enumerate([h*2, h, h/2])
    be_sol = badeuler(x0, tspan, h)
    # mean error plot
    err = abs.(getfield.(be_sol[2], :val) .- an_mean_sol.(μₓ, be_sol[1]))
    plot!(p1, be_sol[1], err; label = "", color = cols[2], linestyle = lstyles[i], linewidth = lwdh)
    # std. deviation plot
    # err = abs.(getfield.(be_sol[2], :err).^1 .- sitp.(be_sol[1]).^1)
    err = getfield.(be_sol[2], :err)
    plot!(p2, be_sol[1], err;
        label = lbls[i], color = cols[2], linestyle = lstyles[i], linewidth = lwdh)
end
be_plot = plot(p1, p2; size = (1000, 330).*0.65, left_margin = 3Plots.Measures.mm, bottom_margin = 4Plots.Measures.mm)

# savefig(be_plot, "be_plot.pdf")

be_plot

In [None]:
# Re-render the step-size figure at a fixed size for export.
ptmp = plot(be_plot, size = (650, 450))
# savefig(ptmp, "step_size_dep.pdf")

#### Restarts

In [None]:
# Restart experiment. tspan is split into `ninterval` pieces. At every interval edge
# the ensemble restarts from a Gaussian matched to the sample mean/variance at that
# edge: each state is redrawn independently of its offset β.
ninterval = 3
interval_edges = range(tspan[1], tspan[end], length = ninterval+1)
# current restart mean/variance, updated after each interval
μin = μₓ
Sin = Σₓ
x0 = [μₓ]

# one fixed offset per trajectory index, reused across all intervals
nsims = 30000
vbeta = randn(nsims) * sqrt(b)

# labels are set only for the last interval so each legend entry appears once
transient_labels = ["", "", "trans."]
asym_labels = ["", "", "asym."]
full_labels = ["", "", "sampl."]

p3 = plot(; xlabel = "time t", ylabel = "variance", legend = :topleft, ylim = (0., 0.017))

for i in 1:ninterval
    # The restart state is carried between intervals in the globals μin/Sin, which are
    # read by prob_func and transient. Declaring them global makes this work outside
    # notebook soft scope too. Run as a plain script, the assignments at the end of the
    # loop would otherwise create new loop-local variables, with an ambiguity warning.
    global μin, Sin

    tsp = interval_edges[i:i+1]
    prob = ODEProblem(f, x0, tsp, [β])

    # `i` here is the trajectory index, so trajectory k keeps offset vbeta[k] in every
    # interval while its state is redrawn from N(μin, Sin)
    function prob_func(prob, i, repeat)
        @. prob.p = vbeta[i]
        @. prob.u0 = μin + randn()*sqrt(Sin)
        prob
    end

    ens_prob = EnsembleProblem(prob; prob_func)

    ens_sol = solve(ens_prob; trajectories = nsims);

    eval_itpoints = range(tsp[1], tsp[2], length = ceil(Int, 80/ninterval))
    meanvars = timepoint_meanvar.(Ref(ens_sol),eval_itpoints);
    means = reduce(vcat, getindex.(meanvars, 1))
    vars = reduce(vcat, getindex.(meanvars, 2))

    # time since the restart; the variance splits into a decaying transient part
    # (initial variance) and an asymptotic part (offset), both measured from the restart
    dts = range(0., tsp[2]-tsp[1], length = length(eval_itpoints))
    transient(t) = exp(-2*a*t)*Sin
    asym(t) = b/(a^2)*(1-exp(-a*t))^2
    plot!(p3, eval_itpoints, vars; 
        color = :black, label = full_labels[i], linewidth = 1.8)
    plot!(p3, eval_itpoints, transient.(dts); 
        color = cols[4], label = transient_labels[i])
    plot!(p3, eval_itpoints, asym.(dts); 
        color = cols[5], label = asym_labels[i])

    # moment-match the end of this interval as the start of the next
    μin = means[end] 
    Sin = vars[end]
end

# dashed vertical markers at the two restart times
plot!(p3, interval_edges[2]*ones(2),  [0., 0.015];
    color = :grey60, linestyle = :dash, label = "restart")
plot!(p3, interval_edges[3]*ones(2),  [0., 0.015];
    color = :grey60, linestyle = :dash, label = "")

# keep a copy of this (paper-style) restart panel; p3 is rebuilt by the poster cell
prs = deepcopy(p3)

In [None]:
# for the poster
using LaTeXStrings

# Poster variant of the restart experiment (more trajectories, full-sample reference,
# labelled restart points). It overwrites p3.
ninterval = 3
interval_edges = range(tspan[1], tspan[end], length = ninterval+1)
μin = μₓ
Sin = Σₓ
x0 = [μₓ]

# transient and asymptotic variance contributions; `transient` reads the global Sin
# at call time, so it follows the restart variance as Sin is updated
transient(t) = exp(-2*a*t)*Sin
asym(t) = b/(a^2)*(1-exp(-a*t))^2

nsims = 60000
vbeta = randn(nsims) * sqrt(b)

transient_labels = ["", "", "transient"]
asym_labels = ["", "", "asymptotic"]
full_labels = ["sample from X₀", "sample from X₁", "sample from X₂"]

p3 = plot(; xlabel = "time t", ylabel = "variance", legend = :topleft, ylim = (0., 0.025), size = (1000, 480).*0.65)

# Reference without restarts: one ensemble over the whole tspan with the same offsets.
# NOTE(review): this redefines the global `prob_func` method from the sampling cell
# (same signature), so re-running that cell afterwards would use this version.
prob = ODEProblem(f, x0, tspan, [β])
function prob_func(prob, i, repeat)
    @. prob.p = vbeta[i]
    @. prob.u0 = μin + randn()*sqrt(Sin)
    prob
end
eval_itpoints = range(tspan[1], tspan[2], length = ceil(Int, 240/ninterval))
ens_prob = EnsembleProblem(prob; prob_func)
ens_sol = solve(ens_prob; trajectories = nsims);
meanvars = timepoint_meanvar.(Ref(ens_sol),eval_itpoints);
vars = reduce(vcat, getindex.(meanvars, 2))

# dashed curves: full-sample variance and its transient/asymptotic parts over all of tspan
plot!(p3, eval_itpoints, vars; 
    color = :black, label = "full sample", linewidth = 1.8, linestyle = :dash)
    plot!(p3, eval_itpoints, transient.(eval_itpoints); 
    color = cols[4], label = "", linewidth = 1.8, linestyle = :dash)
plot!(p3, eval_itpoints, asym.(eval_itpoints); 
    color = cols[5], label = "", linewidth = 1.8, linestyle = :dash)

# dotted restart markers (heights differ slightly between the two)
plot!(p3, interval_edges[2]*ones(2),  [0., 0.0145];
    color = :grey70, linestyle = :dot, label = "restart", linewidth = 2.3)
plot!(p3, interval_edges[3]*ones(2),  [0., 0.015];
    color = :grey70, linestyle = :dot, label = "", linewidth = 2.3)

# one grey shade per restart interval
greys = [:grey50, :grey40, :grey30]
# (time, variance) at the end of each interval, marked on the plot afterwards
endp = Float64[]
endv = Float64[]
for i in 1:ninterval
    # μin/Sin carry the restart state between intervals and are read by prob_func and
    # by the global `transient`. Declaring them global makes the update work outside
    # notebook soft scope as well; in a plain script the assignments below would
    # otherwise create loop-local variables.
    global μin, Sin

    tsp = interval_edges[i:i+1]
    prob = ODEProblem(f, x0, tsp, [β])

    # trajectory k keeps offset vbeta[k] in every interval; its state is redrawn
    function prob_func(prob, i, repeat)
        @. prob.p = vbeta[i]
        @. prob.u0 = μin + randn()*sqrt(Sin)
        prob
    end

    ens_prob = EnsembleProblem(prob; prob_func)

    ens_sol = solve(ens_prob; trajectories = nsims);

    eval_itpoints = range(tsp[1], tsp[2], length = ceil(Int, 80/ninterval))
    meanvars = timepoint_meanvar.(Ref(ens_sol),eval_itpoints);
    means = reduce(vcat, getindex.(meanvars, 1))
    vars = reduce(vcat, getindex.(meanvars, 2))

    # time since the restart
    dts = range(0., tsp[2]-tsp[1], length = length(eval_itpoints))

    plot!(p3, eval_itpoints, vars; 
        color = greys[i], label = full_labels[i], linewidth = 1.8)
    plot!(p3, eval_itpoints, transient.(dts); 
        color = cols[4], label = transient_labels[i])
    plot!(p3, eval_itpoints, asym.(dts); 
        color = cols[5], label = asym_labels[i])

    push!(endp, eval_itpoints[end])
    push!(endv, vars[end])

    # moment-match the end of this interval as the start of the next
    μin = means[end] 
    Sin = vars[end]
end

# mark the end-of-interval variances used for the restarts
scatter!(p3, endp, endv; 
        color = :orange, label = "", markersize = 3.2)

# NOTE(review): t₁ = 2/3 and t₂ = 4/3 are hard-coded; they equal interval_edges[2:3]
# only for tspan = [0, 2] and ninterval = 3
plot!(p3, xticks = ([0., 0.5, 2/3, 1., 4/3, 1.5, 2.], ["0.0", "0.5", "\$t_1\$", "1.0", "\$t_2\$", "1.5", "2"]); size = (1000, 480).*0.65)
annotate!([2/3, 4/3], endv[1:2] .+ 0.002, [L"$\Sigma_1$", L"$\Sigma_2$"])

p3

In [None]:
# savefig(p3, "central_plot.pdf")

In [None]:
# Middle column: step-size study above the paper-style restart panel.
# NOTE(review): be_plot is whichever step-size cell ran last (the poster cell overwrites
# the paper one) — confirm which variant is intended.
middle_plot = plot(be_plot, prs; layout = (2, 1))

## Corrected solvers

In [None]:
# iterated flow: exact one-step variance map of the linear flow over one step of the
# global size `h`. With q = exp(-a h), x⁺ = q x + (β/a)(1 - q), so
#   Var(x⁺) = q² S + (b/a²)(1 - q)² + 2 q (1 - q)/a · Cov(x, β),
# and for the state after n steps Cov(x, β) = (b/a)(1 - exp(-a h n)).
# `n` is therefore the number of steps already taken to reach variance `S`
# (n = 0 for the initial state, which is independent of β).
itervar(S, n) = exp(-2*a*h)*S + b/a^2*(1 - exp(-a*h))^2 + 2*b/a^2 *(1 - exp(-a*h))*exp(-a*h)*(1 - exp(-a*h*n))

# Euler steps on the linear model
# fs(n, h) = Σ_{k=0}^{n-1} (1 - a h)^k  (an empty sum, 0.0, for n = 0)
fs(n, h) = sum((1-a*h).^(n-1:-1:0))
# Variance map of one Euler step x⁺ = (1 - a h) x + h β, keeping the state–offset
# covariance Cov(x, β) = h b fs(n, h) of the state after n steps (n = steps already taken).
eulervar(S, n, h) = (1-a*h)^2 * S + h^2*b + 2*fs(n, h)*h^2*b*(1-a*h)

"""
    euler(Σₓ, tspan, h)

Propagate the variance of the Euler discretisation x ← (1 - a h) x + h β of
dx/dt = -a x + β. The state–offset covariance is included (the "corrected" Euler
variance), so the result is the exact variance of the Euler iterates.

Return `(timesteps, eulervarsteps)` on the grid `tspan[1]:h:tspan[2]`, starting
from the initial variance `Σₓ`.
"""
function euler(Σₓ, tspan, h)
    timesteps = collect(range(tspan[1], tspan[2], step = h))
    eulervarsteps = zeros(length(timesteps))
    eulervarsteps[1] = Σₓ

    for i in 1:length(timesteps)-1
        # eulervarsteps[i] is the variance after i - 1 steps, so use the covariance
        # of the state after i - 1 steps. It is zero for the initial state, which is
        # independent of β.
        eulervarsteps[i+1] = eulervar(eulervarsteps[i], i - 1, h)
    end
    return timesteps, eulervarsteps
end

# analytical solution of dx/dt = -a x + β for independent x(0) (mean μₓ, variance Σₓ)
# and offset β (mean 0, variance b):
#   E[x(t)]   = e^{-a t} μₓ
#   Var[x(t)] = e^{-2 a t} Σₓ + (b/a²)(1 - e^{-a t})²
ana_mean(t) = exp(-a*t)*μₓ
ana_var(t) = exp(-2*a*t)*Σₓ + b/(a^2)*(1-exp(-a*t))^2

In [None]:
# corrected Euler at step sizes h, h/2 and h/4
ts1, ev1 = euler(Σₓ, tspan, h)
ts2, ev2 = euler(Σₓ, tspan, h/2)
ts4, ev4 = euler(Σₓ, tspan, h/4)

# iterated flow: compose the exact one-step variance map `itervar` on a grid of step h
timesteps = collect(range(tspan[1], tspan[2], step = h))
varsteps = zeros(length(timesteps))
varsteps[1] = Σₓ

for i in 1:length(timesteps)-1
    # varsteps[i] is the variance after i - 1 flow steps. itervar's second argument is
    # the number of steps already taken (no state–offset covariance for the initial
    # state), so pass i - 1. The iterates then reproduce ana_var on the grid up to rounding.
    varsteps[i+1] = itervar(varsteps[i], i - 1)
end

In [None]:
# Variance comparison: analytic solution, two sampling estimates (500k and 50k
# trajectories), iterated flow, and corrected Euler at three step sizes.
p = plot(; legend = :topleft, xlabel = "time t", ylabel = "variance")
plot!(p, eval_tpoints, ana_var.(eval_tpoints);  
    label = "analytic", color = cols[1])
plot!(p, eval_tpoints, sample_std.^2; 
    label = "sampling, 500k", color = :black, linestyle = :dash)
plot!(p, eval_tpoints, sample_std_low.^2; 
    label = "sampling, 50k", color = :black)
plot!(p, timesteps, varsteps; 
    label = "flow steps", color = cols[3])
plot!(p, ts1, ev1; 
    label = "corr. euler, h", color = cols[4])
plot!(p, ts2, ev2; 
    label = "corr. euler, h/2", color = cols[4], linestyle = :dash)
plot!(p, ts4, ev4; 
    label = "corr. euler, h/4", color = cols[4], linestyle = :dot)

# Inset (subplot 2): relative error |Var_method - Var_analytic| / Var_analytic for each
# curve of the main panel; the analytic reference is drawn as the zero line.
# pin = plot()
err = abs.(sample_std.^2 .- ana_var.(eval_tpoints)) ./ana_var.(eval_tpoints)
plot!(p, eval_tpoints, err; inset = (1, bbox(0.43,0.37,0.55,0.57)), subplot = 2,
    label = "", color = :black, linestyle = :dash, ylabel = "error") # , xlabel = "time t",
plot!(p, timesteps, zeros(length(timesteps)); subplot = 2,
    label = "", color = cols[1])
err = abs.(sample_std_low.^2 .- ana_var.(eval_tpoints)) ./ana_var.(eval_tpoints)
plot!(p, eval_tpoints, err, label = "", color = :black, subplot = 2)
err = abs.(varsteps .- ana_var.(timesteps)) ./ ana_var.(timesteps)
plot!(p, timesteps, err; subplot = 2,
    label = "", color = cols[3])
err = abs.(ev1 .- ana_var.(ts1)) ./ ana_var.(ts1)
plot!(p, ts1, err; subplot = 2,
    label = "", color = cols[4])
err = abs.(ev2 .- ana_var.(ts2)) ./ ana_var.(ts2)
plot!(p, ts2, err; subplot = 2,
    label = "", color = cols[4], linestyle = :dash)
err = abs.(ev4 .- ana_var.(ts4)) ./ ana_var.(ts4)
plot!(p, ts4, err; subplot = 2, 
    label = "", color = cols[4], linestyle = :dot)

p

### Paper plot

In [None]:
# Paper figure: step-size panels (p1 above p2) next to the restart panel p3 and the
# variance comparison p.
# NOTE(review): p1/p2/p3 are whichever were defined last; the poster cells overwrite
# the paper-style ones — confirm which variant is intended here.
pl = plot(plot(p1; right_margin = -3Plots.Measures.mm ) ,p2; layout = (2,1))
ptmp2 = plot(pl, p3, p; layout = (1,3), size = (1000, 350), 
left_margin = -1Plots.Measures.mm, bottom_margin = -4Plots.Measures.mm)

In [None]:
# savefig(ptmp2, "lin_model_solvers.pdf")

## Complete plot

In [None]:
# Complete three-column figure: linear prototype model, step-size/restart studies,
# and the corrected-solver variance comparison.
p_final = plot(lin_plot, middle_plot, p; layout = (1,3), size = (1000, 385), 
    left_margin = 0Plots.Measures.mm, bottom_margin = -4Plots.Measures.mm)

In [None]:
# savefig(p_final, "lin_plot.pdf")