Example: fitting the Lotka–Volterra model, adapted from https://diffeqflux.sciml.ai/stable/examples/optimization_ode/

In [None]:
using DifferentialEquations, Flux, Optim, DiffEqFlux, DiffEqSensitivity, Plots

In [None]:
"""
    lotka_volterra!(du, u, p, t)

In-place Lotka–Volterra right-hand side: write d[x, y]/dt into `du`.

`u = [x, y]` holds the prey and predator populations, and `p = [α, β, δ, γ]`.
`t` is unused because the system is autonomous. Returns `nothing`; the result
is the mutation of `du`.
"""
function lotka_volterra!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α * x - β * x * y    # prey: growth minus predation
    du[2] = -δ * y + γ * x * y   # predator: decay plus growth from predation
    return nothing
end

# Model parameters, ordered the way lotka_volterra! unpacks them: p = [α, β, δ, γ]
p = [1.5, 1.0, 3.0, 1.0]

# Starting populations [x(0), y(0)]
u0 = [1.0, 1.0]

# Integration window, and the grid of output times at which solutions are saved
tspan = (0.0, 10.0)
tsteps = 0.0:0.1:10.0

# Build the ODE problem and integrate it with Tsit5
prob = ODEProblem(lotka_volterra!, u0, tspan, p)
sol = solve(prob, Tsit5())


In [None]:
# Plot the solution using Plots' plotly backend. `using Plotly` (the separate
# Plotly.jl package) exports names that clash with Plots' own exports, which
# forces callers to qualify them (e.g. `Plots.scatter`).
plotly()
plot(sol)
# savefig("LV_ode.png")



In [None]:
"""
    loss(p)

Solve the global `prob` with parameters `p`, saving at `tsteps`, and return
`(l, sol)`. `l` is the sum of squared deviations of every saved state from 2,
and `sol` is the solution, which is passed on to the callback.
"""
function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = tsteps)
    # Named `l` so that the local does not shadow the function name `loss`.
    l = sum(abs2, sol .- 2)
    return l, sol
end

# sciml_train calls this with the current parameters, the loss value, and the
# extra value returned by `loss` (the solution). Plotting is opt-in so that
# iterations do not pay for building plots that are never shown.
callback = function (p, l, pred; doplot = false)
    if doplot
        display(l)
        display(plot(pred, ylim = (0, 6)))
    end
    # Tell sciml_train not to halt the optimization; returning true stops it.
    return false
end

result_ode = DiffEqFlux.sciml_train(loss, p, ADAM(0.1); cb = callback, maxiters = 100)

In [None]:
# Re-solve with the optimized parameter vector. Take it explicitly from the
# result object rather than passing the result itself as the parameters.
prob2 = remake(prob, p = result_ode.minimizer)
sol2 = solve(prob2, Tsit5())
plot(sol2)


Next, modify the loss so that the initial conditions are optimized together with the parameters.

In [None]:
"""
    loss(p1)

Loss that optimizes the initial condition together with the parameters.
`p1` packs `[x0, y0, α, β, δ, γ]`. Returns `(l, sol)`, where `l` is the sum of
squared deviations of every saved state from 2.
"""
function loss(p1)
    u0 = p1[1:2]
    p = p1[3:end]
    # Override both u0 and p through solve's keyword arguments. The global
    # `prob` supplies only the right-hand side and tspan.
    sol = solve(prob, Tsit5(), u0 = u0, p = p, saveat = tsteps)
    l = sum(abs2, sol .- 2)
    return l, sol
end

In [None]:
# Initial guess for [x0, y0, α, β, δ, γ]; this run uses no callback.
result_ode = DiffEqFlux.sciml_train(
    loss,
    [1.0, 1.0, 1.5, 1.0, 3.0, 1.0],
    ADAM(0.1);
    maxiters = 100,
)

Show the solution obtained with the optimized initial conditions and parameters.

In [None]:
 # Unpack the optimized vector [x0, y0, α, β, δ, γ] and re-solve with it.
θ_opt = result_ode.minimizer
prob = ODEProblem(lotka_volterra!, θ_opt[1:2], tspan, θ_opt[3:end])
sol = solve(prob, Tsit5(), saveat = tsteps)
plot(sol)

This optimization approach is a convenient way to find ODE parameters that fit a target.
However, it yields only a point estimate, not a distribution over the parameters.

To obtain a posterior distribution instead, see the Bayesian approach:
https://turing.ml/dev/tutorials/10-bayesiandiffeq/
https://github.com/TuringLang/TuringTutorials/blob/master/10_diffeq.ipynb


In [None]:
using Turing, Distributions, DifferentialEquations 

# Import MCMCChain, Plots, and StatsPlots for visualizations and diagnostics.
using MCMCChains, Plots, StatsPlots

# Seed the global RNG so that the synthetic noise and the MCMC draws are reproducible.
using Random
Random.seed!(14);

# Turing's progress meter is left enabled here; `sample(...; progress=false)`
# turns it off for an individual call.

using Logging
# Drop every log message at Warn level or below, so only errors are printed.
# NOTE(review): presumably meant to hide ODE-solver warnings from extreme
# parameter draws during sampling; this also hides all other warnings.
Logging.disable_logging(Logging.Warn)
# LogLevel(1001)


In [None]:
"""
    lotka_volterra(du, u, p, t)

In-place Lotka–Volterra right-hand side: write d[x, y]/dt into `du`.

`u = [x, y]`. Note that the parameter order is `p = [α, β, γ, δ]`, which differs
from `lotka_volterra!`'s `[α, β, δ, γ]`. Returns `nothing`.
"""
function lotka_volterra(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = (α - β * y) * x   # dx
    du[2] = (δ * x - γ) * y   # dy
    return nothing
end
u0 = [1.0, 1.0]
p = [1.5, 1.0, 3.0, 1.0]   # [α, β, γ, δ] in lotka_volterra's ordering
prob1 = ODEProblem(lotka_volterra, u0, (0.0, 10.0), p)
sol = solve(prob1, Tsit5())
plot(sol)

In [None]:
# Synthetic observations: the solution saved every 0.1 plus Gaussian noise (sd 0.8).
sol1 = solve(prob1, Tsit5(), saveat = 0.1)
odedata = let clean = Array(sol1)
    clean + 0.8 * randn(size(clean))
end
plot(sol1, alpha = 0.3, legend = false)
Plots.scatter!(sol1.t, odedata')

In [None]:
Turing.setadbackend(:forwarddiff)

@model function fitlv(data, prob1)
    # Priors: observation noise σ and the four parameters in lotka_volterra's order.
    σ ~ InverseGamma(2, 3) # ~ is the tilde character
    α ~ truncated(Normal(1.5, 0.5), 0.5, 2.5)
    β ~ truncated(Normal(1.2, 0.5), 0, 2)
    γ ~ truncated(Normal(3.0, 0.5), 1, 4)
    δ ~ truncated(Normal(1.0, 0.5), 0, 2)

    # Solve the ODE for this parameter draw at the observation times.
    p = [α, β, γ, δ]
    prob = remake(prob1, p = p)
    predicted = solve(prob, Tsit5(), saveat = 0.1)

    # If the solver stopped early for these parameters, fewer time points come
    # back than there are observations. Reject the draw rather than silently
    # scoring only a prefix of the data.
    if length(predicted.t) != size(data, 2)
        Turing.@addlogprob! -Inf
        return nothing
    end

    # Column i of `data` is observed at time predicted.t[i]; `.u[i]` is the
    # state vector there.
    for i in eachindex(predicted.u)
        data[:, i] ~ MvNormal(predicted.u[i], σ)
    end
end

model = fitlv(odedata, prob1)

# This next command runs 3 independent chains without using multithreading.
chain = mapreduce(_ -> sample(model, NUTS(0.65), 1000), chainscat, 1:3)

In [None]:
# Only a cell's final value is rendered automatically, so display the chain
# diagnostics explicitly before starting the data scatter used by the next cell.
display(plot(chain))
pl = Plots.scatter(sol1.t, odedata');

In [None]:
# Overlay 300 posterior-predictive trajectories on the current plot (the data scatter).
chain_array = Array(chain)
for _ in 1:300
    # Draw from all rows (iterations × chains). A hard-coded bound such as
    # 1:1500 would ignore part of the 3 × 1000 draws.
    # NOTE(review): assumes columns 1:4 are α, β, γ, δ — confirm against names(chain).
    θ = chain_array[rand(axes(chain_array, 1)), 1:4]
    resol = solve(remake(prob1, p = θ), Tsit5(), saveat = 0.1)
    plot!(resol, alpha = 0.1, color = "#BBBBBB", legend = false)
end
plot!(sol1, w = 1, legend = false)

Missing prey data: only the predator population (y, the second row of `odedata`) is observed.


In [None]:
@model function fitlv2(data, prob1) # data should be a Vector
    # Priors: observation noise σ and the four parameters in lotka_volterra's order.
    σ ~ InverseGamma(2, 3) # ~ is the tilde character
    α ~ truncated(Normal(1.5, 0.5), 0.5, 2.5)
    β ~ truncated(Normal(1.2, 0.5), 0, 2)
    γ ~ truncated(Normal(3.0, 0.5), 1, 4)
    δ ~ truncated(Normal(1.0, 0.5), 0, 2)

    p = [α, β, γ, δ]
    prob = remake(prob1, p = p)
    predicted = solve(prob, Tsit5(), saveat = 0.1)

    # Reject draws for which the solver stopped before covering all observations.
    if length(predicted.t) != length(data)
        Turing.@addlogprob! -Inf
        return nothing
    end

    for i in eachindex(predicted.u)
        # predicted.u[i][2] is the model's y at time i, a scalar, so Normal is
        # used instead of MvNormal.
        data[i] ~ Normal(predicted.u[i][2], σ)
    end
end

model2 = fitlv2(odedata[2, :], prob1)

In [None]:
# Show how many threads this Julia session has. The MCMCThreads() sampling in
# the next cell presumably needs more than one to run its chains in parallel.
Threads.nthreads()


In [None]:
# 3 chains of 5000 draws each, sampled via MCMCThreads() with the progress meter off.
chain2 = sample(model2, NUTS(0.45), MCMCThreads(), 5000, 3; progress = false)


In [None]:
pl = Plots.scatter(sol1.t, odedata');
chain_array2 = Array(chain2)
for _ in 1:300
    # Draw from all rows (iterations × chains). A hard-coded bound such as
    # 1:12000 would ignore part of the 3 × 5000 draws.
    θ = chain_array2[rand(axes(chain_array2, 1)), 1:4]
    resol = solve(remake(prob1, p = θ), Tsit5(), saveat = 0.1)
    # Note that due to a bug in AxisArray, the variables from the chain will be returned always in
    # the order it is stored in the array, not by the specified order in the call - :α, :β, :γ, :δ
    plot!(resol, alpha = 0.1, color = "#BBBBBB", legend = false)
end
plot!(sol1, w = 1, legend = false)