# Bayesian Estimation of Differential Equations

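Differential equations describe how a system changes over time, and they are used throughout physics, chemistry, biology, and engineering. Their parameters (growth rates, reaction constants, and so on) are rarely known in advance and must be estimated from data. Bayesian estimation does this while quantifying uncertainty: instead of a single best-fit parameter vector, it returns a posterior distribution over the parameters that combines prior knowledge with the data. This tutorial shows how to do this with Turing.jl and DifferentialEquations.jl for ordinary, stochastic, and delay differential equations.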

In [None]:
using Turing, Distributions, DataFrames, DifferentialEquations

# Import MCMCChains, Plots, and StatsPlots for visualizations and diagnostics.
using MCMCChains, Plots, StatsPlots

# Set a seed for reproducibility.
using Random
Random.seed!(12);

## The Lotka-Volterra Model

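The Lotka-Volterra equations are a classic predator-prey model. With $x$ the prey population and $y$ the predator population,

$$
\frac{dx}{dt} = (\alpha - \beta y)\,x, \qquad \frac{dy}{dt} = (\delta x - \gamma)\,y,
$$

where $\alpha$ is the prey growth rate, $\beta$ the rate at which predators consume prey, $\gamma$ the predator death rate, and $\delta$ the rate at which predators grow by consuming prey.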

In [None]:
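function lotka_volterra(du,u,p,t)
  x, y = u
  # Unpack in the order [α, β, γ, δ], the same order the Turing models use for p.
  α, β, γ, δ = p
  du[1] = (α - β*y)*x   # dx/dt: prey growth minus predation
  du[2] = (δ*x - γ)*y   # dy/dt: predator growth minus death
  return nothing
end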
p = [2.2, 1.0, 2.0, 0.4]
u0 = [1.0,1.0]
prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
sol = solve(prob,Tsit5())
plot(sol)

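`lotka_volterra` is written in DifferentialEquations.jl's in-place form `f(du,u,p,t)`: it writes the derivatives into `du` given the state `u`, the parameters `p`, and the time `t`. `ODEProblem` bundles this function with the initial condition `u0`, the time span `(0.0, 10.0)`, and the parameters, and `solve` integrates it with `Tsit5()`, Tsitouras' adaptive explicit Runge-Kutta method of order 5(4). The plot shows the characteristic predator-prey cycle: prey grow, predators follow, prey decline, and then predators decline.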
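To make concrete what `solve` does with the in-place function, here is a minimal fixed-step fourth-order Runge-Kutta integrator written with Base Julia only. It is an illustrative sketch, not `Tsit5` (which adapts its step size using an error estimate); `lv!`, `rk4_solve`, and `lv_invariant` are hypothetical names, and the numbers from `p` above are read in the order α, β, γ, δ. As a sanity check it tracks the quantity V = δx − γ ln x + βy − α ln y, which the exact Lotka-Volterra solution conserves.

```julia
# Lotka-Volterra right-hand side with parameters unpacked as α, β, γ, δ.
function lv!(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = (α - β * y) * x
    du[2] = (δ * x - γ) * y
    return nothing
end

# Hypothetical helper: classic fixed-step RK4 for an in-place f!(du, u, p, t).
# Returns the time points and a matrix with one column per saved state.
function rk4_solve(f!, u0, tspan, p, dt)
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    k1, k2, k3, k4 = similar(u), similar(u), similar(u), similar(u)
    ts = [t0 + i * dt for i in 0:nsteps]
    us = zeros(length(u), nsteps + 1)
    us[:, 1] .= u
    for i in 1:nsteps
        t = ts[i]
        f!(k1, u, p, t)
        f!(k2, u .+ (dt / 2) .* k1, p, t + dt / 2)
        f!(k3, u .+ (dt / 2) .* k2, p, t + dt / 2)
        f!(k4, u .+ dt .* k3, p, t + dt)
        u .+= (dt / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        us[:, i + 1] .= u
    end
    return ts, us
end

# Conserved quantity of the Lotka-Volterra system, p = [α, β, γ, δ].
lv_invariant(x, y, p) = p[4] * x - p[3] * log(x) + p[2] * y - p[1] * log(y)

p_lv = [2.2, 1.0, 2.0, 0.4]
ts, us = rk4_solve(lv!, [1.0, 1.0], (0.0, 10.0), p_lv, 0.001)

# Largest deviation of the invariant from its initial value along the trajectory.
drift = maximum(abs(lv_invariant(us[1, j], us[2, j], p_lv) - lv_invariant(1.0, 1.0, p_lv))
                for j in axes(us, 2))
```

A small `drift` means the fixed-step scheme stays on the closed orbit; an adaptive method like `Tsit5` reaches similar accuracy with far fewer steps.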

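To generate synthetic data for the fit, we solve the problem again, this time passing `saveat=0.1` so that the solution is stored only at t = 0.0, 0.1, …, 10.0 instead of at every adaptive step. `saveat` is one of the common solver options documented at https://docs.sciml.ai/latest/basics/common_solver_opts/ . Converting the solution with `Array` gives a 2×101 matrix whose columns are the saved states. Note that this data is noise-free.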

In [None]:
odedata = Array(solve(prob,Tsit5(),saveat=0.1))
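A quick way to see what `saveat` does is to inspect the saved time points. The sketch below assumes OrdinaryDiffEq.jl (the ODE solver package that DifferentialEquations.jl re-exports) and re-declares the model so that the block runs on its own; `lv!`, `prob_check`, and `sol_grid` are local names for illustration.

```julia
using OrdinaryDiffEq

function lv!(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = (α - β * y) * x
    du[2] = (δ * x - γ) * y
    return nothing
end

prob_check = ODEProblem(lv!, [1.0, 1.0], (0.0, 10.0), [2.2, 1.0, 2.0, 0.4])

# With saveat=0.1 the solution is stored only at 0.0, 0.1, ..., 10.0.
sol_grid = solve(prob_check, Tsit5(), saveat=0.1)

# 2×101 matrix: row 1 is the prey x, row 2 the predator y, one column per saved time.
data_grid = Array(sol_grid)
```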

## Fitting Lotka-Volterra with DiffEqBayes

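The simplest route is the high-level package DiffEqBayes.jl (https://github.com/SciML/DiffEqBayes.jl; see its documentation for the available inference backends). Its `turing_inference` function takes the problem, the solver, the time points of the data, the data, and a vector of priors with one prior per parameter in the order of `p`, and then builds and samples a Turing model for you: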

In [None]:
using DiffEqBayes
t = 0:0.1:10.0
priors = [Truncated(Normal(1.5,0.5),0.5,2.5),Truncated(Normal(1.2,0.5),0,2),Truncated(Normal(3.0,0.5),1,4),Truncated(Normal(1.0,0.5),0,2)]
bayesian_result_turing = turing_inference(prob,Tsit5(),t,odedata,priors,num_samples=10_000)
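Because the priors are matched to the parameters by position, it is worth checking that each data-generating value lies inside the support of its truncated prior; a prior with zero density at the true value can never recover it. A minimal sketch with Distributions.jl, assuming the lower-case `truncated` constructor (the recommended spelling of `Truncated(...)` in recent Distributions.jl versions); `priors_check` and `p_true` are illustrative names.

```julia
using Distributions

# Same priors as above, in the same positional order as the parameter vector p.
priors_check = [truncated(Normal(1.5, 0.5), 0.5, 2.5),
                truncated(Normal(1.2, 0.5), 0.0, 2.0),
                truncated(Normal(3.0, 0.5), 1.0, 4.0),
                truncated(Normal(1.0, 0.5), 0.0, 2.0)]

p_true = [2.2, 1.0, 2.0, 0.4]   # the values used to generate odedata

# true if each true value has nonzero prior density
inside = [insupport(d, v) for (d, v) in zip(priors_check, p_true)]
# (lower, upper) bound of each prior
supports = [(minimum(d), maximum(d)) for d in priors_check]
```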

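After fitting, check two things. First, have the chains converged? Trace plots (for example `plot(chain)` with StatsPlots) and the effective sample size and R-hat values that MCMCChains reports in a chain's summary statistics help answer this. Second, does the posterior recover the parameters used to generate the data, `p = [2.2, 1.0, 2.0, 0.4]`?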
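To show what R-hat measures, here is a minimal sketch of the classic (non-split) Gelman-Rubin statistic using only Julia's standard library. `rhat` is a hypothetical helper, and MCMCChains computes its own variant. The idea is to compare the variance between chains with the variance within chains: values close to 1 mean independent chains agree, and values well above 1 mean they have not mixed.

```julia
using Random, Statistics

# Hypothetical helper: classic Gelman-Rubin R-hat for one parameter.
# `draws` is an (iterations × chains) matrix.
function rhat(draws::AbstractMatrix{<:Real})
    n = size(draws, 1)                       # draws per chain
    chain_means = vec(mean(draws; dims = 1))
    W = mean(vec(var(draws; dims = 1)))      # average within-chain variance
    B = n * var(chain_means)                 # between-chain variance
    var_plus = (n - 1) / n * W + B / n       # pooled estimate of the posterior variance
    return sqrt(var_plus / W)
end

rng = MersenneTwister(42)
mixed = randn(rng, 1000, 4)                        # four chains sampling the same distribution
stuck = randn(rng, 1000, 4) .+ [0.0 0.0 0.0 3.0]   # fourth chain stuck around a different value
rhat(mixed), rhat(stuck)
```

Here `rhat(mixed)` is close to 1, while `rhat(stuck)` is well above 1 because one chain disagrees with the others. With a single chain, as in the short runs below, R-hat cannot be computed this way, which is one reason to sample several chains.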

## Direct Handling of Bayesian Estimation with Turing

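We can also call the differential equation solver directly inside a Turing model, which gives full control over the priors, the likelihood, and the solver options. Here is essentially the same model written directly in Turing, with an explicit `InverseGamma(2, 3)` prior on the observation noise σ: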

In [None]:
Turing.setadbackend(:forwarddiff)

@model function fitlv(data, ::Type{T}=Vector{Float64}) where {T}
    σ ~ InverseGamma(2, 3)
    α ~ Truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ Truncated(Normal(1.2,0.5),0,2)
    γ ~ Truncated(Normal(3.0,0.5),1,4)
    δ ~ Truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = solve(prob,Tsit5(),saveat=0.1)

    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ[1]*ones(length(predicted[i])))
    end
end

model = fitlv(odedata)
chain = sample(model, NUTS(.65),30)
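Each `data[:,i] ~ MvNormal(predicted[i], σ[1]*ones(...))` statement in the model above adds the log-density of a Gaussian with mean `predicted[i]` and standard deviation σ in every component to the model's log joint. With a diagonal covariance this is a sum of independent univariate Normal log-densities. The sketch below checks that identity with Distributions.jl on made-up numbers; `gaussian_loglik` is a hypothetical helper, and the small matrices stand in for `data` and the solver output.

```julia
using Distributions

# Log-likelihood of y given predictions μ under y ~ N(μ, σ² I), written out by hand:
# -n/2 * log(2πσ²) - Σ (y - μ)² / (2σ²)
function gaussian_loglik(y::AbstractVector, μ::AbstractVector, σ::Real)
    n = length(y)
    return -n / 2 * log(2π * σ^2) - sum(abs2, y .- μ) / (2σ^2)
end

data_demo      = [1.0 1.1 1.3; 1.0 0.9 0.8]    # 2 states × 3 time points (made up)
predicted_demo = [1.0 1.2 1.25; 1.0 0.95 0.85]
σ_demo = 0.3

# Summing over time points mirrors the `for i = 1:length(predicted)` loop in the model.
ll_manual = sum(gaussian_loglik(data_demo[:, i], predicted_demo[:, i], σ_demo)
                for i in axes(data_demo, 2))

# Same quantity as a sum of elementwise univariate Normal log-densities.
ll_dists = sum(logpdf.(Normal.(predicted_demo, σ_demo), data_demo))
```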

## Scaling to Large Models: Adjoint Sensitivities

DifferentialEquations.jl is efficient even for large, stiff models; see the tutorial on solving large stiff equations at https://docs.sciml.ai/latest/tutorials/advanced_ode_example/ and the benchmarks in https://github.com/SciML/DiffEqBenchmarks.jl .

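For models with many parameters, computing gradients becomes the bottleneck, and this is where sensitivity analysis comes in (https://docs.sciml.ai/latest/analysis/sensitivity/). Automatic differentiation (AD) is pervasive in Julia, which has several AD systems, such as ForwardDiff, ReverseDiff, Tracker, and Zygote. `concrete_solve` plugs into these AD systems so that, when one of them differentiates through the solver, you can choose among advanced sensitivity analysis (derivative calculation) methods: https://docs.sciml.ai/latest/analysis/sensitivity/#Sensitivity-Algorithms-1 . The mathematics behind these methods is described at https://docs.sciml.ai/latest/extras/sensitivity_math/ .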

The second argument of `concrete_solve` is the solver. Passing `nothing` lets DifferentialEquations.jl choose one automatically; the automatic choice normally detects stiffness, so it is a safe default. A specific solver can be given instead, for example `concrete_solve(prob,Tsit5(),saveat=0.1)`.

While these sensitivity analysis methods may seem complicated (and they are!), using them is dead simple. Here is a version of the Lotka-Volterra model with adjoints enabled:

In [None]:
using ReverseDiff
Turing.setadbackend(:reversediff)

@model function fitlv(data, ::Type{T}=Vector{Float64}) where {T}
    σ ~ InverseGamma(2, 3)
    α ~ Truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ Truncated(Normal(1.2,0.5),0,2)
    γ ~ Truncated(Normal(3.0,0.5),1,4)
    δ ~ Truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = concrete_solve(prob,nothing,saveat=0.1)

    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ[1]*ones(length(predicted[i])))
    end
end;

model = fitlv(odedata)
chain = sample(model, NUTS(.65),30)

All we had to do was switch the AD backend to one of the adjoint-compatible backends (ReverseDiff, Tracker, or Zygote), and the solver then uses adjoint methods automatically. Notice that on this model adjoints are slower. Adjoints have a higher overhead on models with few parameters, so we suggest using them only for models with around 100 parameters or more. For more details, see https://arxiv.org/abs/1812.01892 .
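A rough count shows why adjoints only pay off for larger models. For a system with $N$ states and $P$ parameters, forward sensitivity analysis integrates the original equations together with one sensitivity equation per state and parameter, while an adjoint method first solves the original equations forward and then solves a backward system for the adjoint state $\lambda$ ($N$ components) and the parameter gradient integral ($P$ components). Ignoring checkpointing and the cost of the vector-Jacobian products, the number of ODE components integrated is roughly

$$
\underbrace{N(1 + P)}_{\text{forward sensitivities}}
\qquad \text{versus} \qquad
\underbrace{N}_{\text{forward pass}} + \underbrace{N + P}_{\text{backward pass}},
$$

with one more $N$ in the backward pass for methods such as `BacksolveAdjoint` that re-solve the state backward in time. For the Lotka-Volterra model ($N = 2$, $P = 4$) this is 10 versus 8 components, so the per-step overhead of the reverse pass dominates; for $N = 2$ and $P = 100$ it is 202 versus 104.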

We can control which sensitivity analysis method is used through the `sensealg` keyword argument. Here we choose `InterpolatingAdjoint` from https://docs.sciml.ai/latest/analysis/sensitivity/#Sensitivity-Algorithms-1 and enable a compiled ReverseDiff vector-Jacobian product (`ReverseDiffVJP(true)`):

In [None]:
@model function fitlv(data, ::Type{T}=Vector{Float64}) where {T}
    σ ~ InverseGamma(2, 3)
    α ~ Truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ Truncated(Normal(1.2,0.5),0,2)
    γ ~ Truncated(Normal(3.0,0.5),1,4)
    δ ~ Truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]
    prob = ODEProblem(lotka_volterra,u0,(0.0,10.0),p)
    predicted = concrete_solve(prob,nothing,saveat=0.1,
                               sensealg=InterpolatingAdjoint(
                                               autojacvec=ReverseDiffVJP(true)))

    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ[1]*ones(length(predicted[i])))
    end
end;
model = fitlv(odedata)
chain = sample(model, NUTS(.65),30)

For more examples of adjoint usage on models with many parameters, consult the [DiffEqFlux documentation](https://diffeqflux.sciml.ai/dev/).

## Including Process Noise: Estimation of Stochastic Differential Equations

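Real populations fluctuate randomly, which we can model by adding process noise to the Lotka-Volterra equations. With independent Wiener processes $W_1$ and $W_2$ and noise scales $\phi_1$ and $\phi_2$, the stochastic differential equation (SDE) is

$$
dx = (\alpha - \beta y)\,x\,dt + \phi_1 x\,dW_1, \qquad dy = (\delta x - \gamma)\,y\,dt + \phi_2 y\,dW_2.
$$

The drift reuses `lotka_volterra`, and `lotka_volterra_noise` gives the diagonal, multiplicative noise term, with $\phi_1$ and $\phi_2$ stored as `p[5]` and `p[6]`.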

In [None]:
function lotka_volterra_noise(du,u,p,t)
    du[1] = p[5]*u[1]
    du[2] = p[6]*u[2]
end
p = [1.5, 1.0, 3.0, 1.0, 0.3, 0.3]
prob = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p)
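To show what solving this SDE means, here is a minimal Euler-Maruyama integrator for diagonal noise, written with Julia's standard library only. It is an illustrative sketch: `lv_drift!`, `lv_noise!`, and `euler_maruyama` are hypothetical names, and `solve` uses an adaptive SDE algorithm rather than this fixed-step scheme. Each step adds f(u)Δt + g(u)ΔW, where each component of ΔW is drawn independently from N(0, Δt).

```julia
using Random

function lv_drift!(du, u, p, t)
    x, y = u
    α, β, γ, δ = p
    du[1] = (α - β * y) * x
    du[2] = (δ * x - γ) * y
    return nothing
end

# Diagonal multiplicative noise: component i gets p[4+i] * u[i] dW_i.
function lv_noise!(du, u, p, t)
    du[1] = p[5] * u[1]
    du[2] = p[6] * u[2]
    return nothing
end

# Hypothetical helper: fixed-step Euler-Maruyama for in-place f!(du,u,p,t) and g!(du,u,p,t).
function euler_maruyama(f!, g!, u0, tspan, p, dt; rng = Random.default_rng())
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    u = copy(u0)
    drift, diffusion = similar(u), similar(u)
    us = zeros(length(u), nsteps + 1)
    us[:, 1] .= u
    for i in 1:nsteps
        t = t0 + (i - 1) * dt
        f!(drift, u, p, t)
        g!(diffusion, u, p, t)
        dW = sqrt(dt) .* randn(rng, length(u))   # Brownian increments ~ N(0, dt)
        u .+= drift .* dt .+ diffusion .* dW
        us[:, i + 1] .= u
    end
    return us
end

p_sde = [1.5, 1.0, 3.0, 1.0, 0.3, 0.3]
path1 = euler_maruyama(lv_drift!, lv_noise!, [1.0, 1.0], (0.0, 10.0), p_sde, 0.001;
                       rng = MersenneTwister(1))
path2 = euler_maruyama(lv_drift!, lv_noise!, [1.0, 1.0], (0.0, 10.0), p_sde, 0.001;
                       rng = MersenneTwister(2))
```

Different random seeds give different paths, while setting `p[5]` and `p[6]` to zero removes the noise and reduces the scheme to the deterministic explicit Euler method.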

Because the noise is random, each solve produces a different trajectory. Solving three times shows this:

In [None]:
# A cell only displays its last value, so collect the three solutions in one figure.
plots = [plot(solve(prob,saveat=0.01)) for _ in 1:3]
plot(plots..., layout=(3,1))

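To summarize many realizations, we solve an `EnsembleProblem` with 1000 trajectories and plot an `EnsembleSummary`, which shows the mean trajectory together with quantile bands: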

In [None]:
# SOSRI is an adaptive solver for SDEs with diagonal noise.
sol = solve(EnsembleProblem(prob),SOSRI(),saveat=0.01,trajectories=1000)
summ = EnsembleSummary(sol)
plot(summ)

As data for the fit, we take the mean of the ensemble at each saved time step:

In [None]:
using DiffEqBase.EnsembleAnalysis
averagedata = Array(timeseries_steps_mean(sol))
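`timeseries_steps_mean` and `EnsembleSummary` reduce the ensemble across trajectories at each saved time step. The same reduction can be sketched with the Statistics standard library on a matrix whose rows are time points and whose columns are trajectories; this layout and the helper `ensemble_stats` are assumptions for illustration, not the EnsembleAnalysis data structures.

```julia
using Statistics

# Pointwise mean and quantile band across trajectories for one state component.
# `paths` is a (time points × trajectories) matrix.
function ensemble_stats(paths::AbstractMatrix; qlow = 0.05, qhigh = 0.95)
    μ  = vec(mean(paths; dims = 2))
    lo = [quantile(view(paths, i, :), qlow) for i in axes(paths, 1)]
    hi = [quantile(view(paths, i, :), qhigh) for i in axes(paths, 1)]
    return μ, lo, hi
end

# Toy ensemble: 3 time points, 5 trajectories.
paths_demo = [1.0 1.0 1.0 1.0 1.0;
              0.8 0.9 1.0 1.1 1.2;
              0.5 1.0 1.5 2.0 2.5]
μ_demo, lo_demo, hi_demo = ensemble_stats(paths_demo)
```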

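Now fit the means with Turing. Inside the model we simulate an ensemble for the proposed parameters and compare its mean with the data: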

In [None]:
Turing.setadbackend(:forwarddiff)

@model function fitlv(data, ::Type{T}=Vector{Float64}) where {T}
    σ ~ InverseGamma(2, 3)
    α ~ Truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ Truncated(Normal(1.2,0.5),0,2)
    γ ~ Truncated(Normal(3.0,0.5),1,4)
    δ ~ Truncated(Normal(1.0,0.5),0,2)
    ϕ1 ~ Truncated(Normal(1.2,0.5),0,2)
    ϕ2 ~ Truncated(Normal(1.2,0.5),0,2)

    p = [α,β,γ,δ,ϕ1,ϕ2]
    prob = SDEProblem(lotka_volterra,lotka_volterra_noise,u0,(0.0,10.0),p)
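    # SOSRI is an adaptive solver for SDEs with diagonal noise.
    ensemble_predicted = solve(EnsembleProblem(prob),SOSRI(),saveat=0.01,trajectories=1000)
    predicted_means = timeseries_steps_mean(ensemble_predicted)

    # Compare the simulated mean with the data at every saved time point.
    for i = 1:length(predicted_means)
        data[:,i] ~ MvNormal(predicted_means[i], σ[1]*ones(length(predicted_means[i])))
    end
end;

# Fit the ensemble means (saved every 0.01), not the ODE data saved every 0.1.
model = fitlv(averagedata)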
chain = sample(model, NUTS(.65),30)

## Delayed Interaction Models: Estimation of Delay Differential Equations

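Many interactions act with a delay; for example, newborn prey may only join the population after a maturation period. Delay differential equations (DDEs) model this by letting the derivative depend on past states. In the delayed Lotka-Volterra model below, prey growth depends on the prey population one time unit earlier:

$$
\frac{dx}{dt} = \alpha\,x(t-1) - \beta\,x(t)\,y(t), \qquad \frac{dy}{dt} = -\gamma\,y(t) + \delta\,x(t)\,y(t),
$$

with the history $x(t) = y(t) = 1$ for $t \le 0$. In the code, the history function `h` returns past states, so `h(p,t-1)[1]` is $x(t-1)$.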

In [None]:
function delay_lotka_volterra(du,u,h,p,t)
  x, y = u
  α,β,γ,δ = p
  du[1] = α*h(p,t-1)[1] - β*x*y
  du[2] = -γ*y + δ*x*y
end
p = (1.5,1.0,3.0,1.0); u0 = [1.0;1.0]
tspan = (0.0,10.0)
_h(p,t) = ones(2)
prob1 = DDEProblem(delay_lotka_volterra,u0,_h,tspan,p)
sol = solve(prob1)
plot(sol)
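The history function is what distinguishes a DDE: whenever the right-hand side needs x(t − 1) at a time before t = 0, the value comes from the history (here the constant 1); afterwards it comes from the solution already computed. A minimal fixed-step Euler sketch makes this explicit, assuming the lag τ = 1 is an exact multiple of the step size so that x(t − τ) is a stored grid value. `dde_euler` is a hypothetical helper; `solve` instead uses an adaptive ODE method with interpolation of the past solution.

```julia
# Hypothetical helper: fixed-step Euler for the delayed Lotka-Volterra model with lag τ
# and constant history x(t) = h0 for t ≤ t0. Column i of the result holds the state at t0 + (i-1)*dt.
function dde_euler(p, u0, h0, tspan, dt; τ = 1.0)
    α, β, γ, δ = p
    t0, t1 = tspan
    nsteps = round(Int, (t1 - t0) / dt)
    lag = round(Int, τ / dt)                 # lag measured in steps
    @assert lag * dt ≈ τ "τ must be a multiple of dt"
    us = zeros(2, nsteps + 1)
    us[:, 1] .= u0
    for i in 1:nsteps
        x, y = us[1, i], us[2, i]
        # x(t - τ): history value before t0, otherwise a stored solution value.
        x_lag = i - lag >= 1 ? us[1, i - lag] : h0
        us[1, i + 1] = x + dt * (α * x_lag - β * x * y)
        us[2, i + 1] = y + dt * (-γ * y + δ * x * y)
    end
    return us
end

us_dde = dde_euler((1.5, 1.0, 3.0, 1.0), [1.0, 1.0], 1.0, (0.0, 10.0), 0.001)
```

With β = 0 the prey equation becomes dx/dt = α x(t − 1), whose exact solution is x(t) = 1 + αt on [0, 1] and x(2) = 1 + 2α + α²/2, which gives a convenient check of the history handling.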

In [None]:
ddedata = Array(solve(prob1,saveat=0.1))

In [None]:
Turing.setadbackend(:forwarddiff)

@model function fitlv(data, ::Type{T}=Vector{Float64}) where {T}
    σ ~ InverseGamma(2, 3)
    α ~ Truncated(Normal(1.5,0.5),0.5,2.5)
    β ~ Truncated(Normal(1.2,0.5),0,2)
    γ ~ Truncated(Normal(3.0,0.5),1,4)
    δ ~ Truncated(Normal(1.0,0.5),0,2)

    p = [α,β,γ,δ]
    prob = DDEProblem(delay_lotka_volterra,u0,_h,tspan,p)
    predicted = solve(prob,saveat=0.1)   # same time spacing as ddedata

    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ[1]*ones(length(predicted[i])))
    end
end;

model = fitlv(ddedata)
chain = sample(model, NUTS(.65),30)