# Estimating the Parameters of Differential Equations using Turing.jl

## Load packages

In [1]:
using DifferentialEquations
using ParameterizedFunctions
using RecursiveArrayTools
using Turing
using Plots
srand(31415926) # fix random seed just for demo purpose
;

[Turing]: AD chunk size is set as 40




## Define  a DE problem using DifferentialEquations.jl

Here `a` is the parameter to estimate and `b`, `c` and `d` are fixed, which can be seen from that `a` uses `=>` and others use `=` for value assignment.

> The Lotka–Volterra equations, also known as the predator–prey equations, are a pair of first-order, nonlinear, differential equations frequently used to describe the dynamics of biological systems in which two species interact, one as a predator and the other as prey. - Wikipedia

In [2]:
f = @ode_def_nohes LotkaVolterraTest begin
    dx = a*x - b*x*y
    dy = -c*y + d*x*y
end a=>1.5 b=1.0 c=3.0 d=1.0
;

Set initial states and span

In [3]:
u0 = [1.0;1.0]
tspan = (0.0,10.0)
;

## Generate data using `a=1.5`

In [4]:
prob = ODEProblem(f,u0,tspan)
sol = solve(prob,Tsit5())
t = collect(linspace(0,10,40)) # I reduce the number of data so that I can see error bars from the 2nd prediction way
sig = 0.49  # nosie level
data = convert(Array, VectorOfArray([(sol(t[i]) + sig*randn(2)) for i in 1:length(t)]))
;

In [5]:
plot_data() = begin
    scatter(t, data[1,:], lab="#prey (data)")
    scatter!(t, data[2,:], lab="#predator (data)")
end

plot_data()

Helper function for update parameter

In [6]:
function problem_new_parameters(prob::ODEProblem,p)
  f = (t,u,du) -> prob.f(t,u,p,du)
  uEltype = eltype(p)
  u0 = [uEltype(prob.u0[i]) for i in 1:length(prob.u0)]
  tspan = (uEltype(prob.tspan[1]),uEltype(prob.tspan[2]))
  ODEProblem(f,u0,tspan)
end
;

## Define a Bayes model to fit parameters

$$ a \sim \text{Truncated}(\text{Normal}(1.5, 1), 0.5, 2.5) $$

$$ \sigma \sim \text{Inverse-Gamma}(2, 3) $$

$$ \text{Compute } sol_a $$

$$ data \sim \text{MvNormal}(sol_a, \sigma I) $$

**Note**: the use of truncated Gaussian for $a$ is just from my trails - I experience some bugs when $a<0$ and when $a$ is large.

In [7]:
@model bfit(x) = begin
    # Define prior
    a ~ Truncated(Normal(1.5, 1), 0.5, 2.5)  # DE param
    σ ~ InverseGamma(2, 3)                   # data noise
    
    # Update solver
    p_tmp = problem_new_parameters(prob, a); sol_tmp = solve(p_tmp,Tsit5())
    
    # Observe data
    # Here you can do a lot of ways to write the observation
    # and a lot of possible optimization to make the program faster
    # E.g. naively
    
    for i = 1:length(t)
        res = sol_tmp(t[i])
        x[:,i] ~ MvNormal(res, σ*ones(2))
    end
end

bfit (generic function with 2 methods)

## Run sampler

In [8]:
chn = sample(bfit(data), HMC(500, 0.02, 4))

[Turing]:  Assume - `a` is a parameter
  in @~(::Any, ::Any) at compiler.jl:76
[Turing]:  Assume - `σ` is a parameter
  in @~(::Any, ::Any) at compiler.jl:76
[Turing]:  Observe - `x` is an observation
  in @~(::Any, ::Any) at compiler.jl:57


[HMC] Sampling...  0%  ETA: 5:38:09[1m[34m
  ϵ:         0.02[0m[1m[34m
  α:         1.0[0m[1m[34m
[HMC] Sampling... 16%  ETA: 0:03:46[1m[34m[A[1G[K[A
  ϵ:         0.02[0m[1m[34m
  α:         0.9023994677893477[0m[1m[34m
[HMC] Sampling... 32%  ETA: 0:01:33[1m[34m[A[1G[K[A
  ϵ:         0.02[0m[1m[34m
  α:         2.658201577940882e-100[0m[1m[34m
[HMC] Sampling... 49%  ETA: 0:00:46[1m[34m[A[1G[K[A
  ϵ:         0.02[0m[1m[34m
  α:         5.866280838468989e-5[0m[1m[34m
[HMC] Sampling... 65%  ETA: 0:00:24[1m[34m[A[1G[K[A
  ϵ:         0.02[0m[1m[34m
  α:         1.0[0m[1m[34m
[HMC] Sampling... 81%  ETA: 0:00:11[1m[34m[A[1G[K[A
  ϵ:         0.02[0m[1m[34m
  α:         0.9985106597954047[0m[1m[34m
[HMC] Sampling... 91%  ETA: 0:00:05[1m[34m[A[1G[K[A
  ϵ:         0.02[0m[1m[34m
  α:         5.141659514483981e-5[0m[1m[34m
  pre_cond:  [1.0,1.0][0m

[HMC] Finished with
  Running time        = 47.387767281999935;
  Accept rate         = 0.552;
  #lf / sample        = 3.992;
  #evals / sample     = 3.996;
  pre-cond. diag mat  = [1.0,1.0].


[1G[K[A[1G[K[A[1G[K[A[HMC] Sampling...100% Time: 0:00:48


Object of type "Turing.Chain"

Iterations = 1:500
Thinning interval = 1
Chains = 1
Samples per chain = 500

[-486.909 0.0 … 2.30228 0.02; -218.738 4.0 … 2.31922 0.02; … ; -55.0666 4.0 … 1.50907 0.02; -55.0666 4.0 … 1.50907 0.02]

## Make prediction 1: averaging samples as point estimate

In [9]:
a_avg = mean(chn[:a])
σ_avg = mean(chn[:σ])
println("E[a] = $a_avg, E[σ] = $σ_avg")
println("a.diff = $(a_avg - 1.5), σ.diff = $(σ_avg - 0.49)")

E[a] = 1.609196623975551, E[σ] = 0.9155962379216105
a.diff = 0.10919662397555108, σ.diff = 0.4255962379216105


In [10]:
tp = collect(linspace(0,10,200)) # t for plot
;

In [11]:
p_avg = problem_new_parameters(prob, a_avg); sol_avg = solve(p_avg,Tsit5())
pred_avg = convert(Array, VectorOfArray([sol_avg(tp[i]) for i in 1:length(tp)]))
plot_data()
plot!(tp, pred_avg[1,:], lab="#prey (prediction)")
plot!(tp, pred_avg[2,:], lab="#predator (prediction)")

## Make prediction 2: the Bayesian way

In [12]:
all_a = chn[:a]
p1 = []
p2 = []
for i = 1:10:length(all_a)    # not taking each sample as that's too expensive
    p_bayes = problem_new_parameters(prob, all_a[i]); sol_bayes = solve(p_bayes,Tsit5())
    pred_bayes = convert(Array, VectorOfArray([sol_avg(tp[i]) for i in 1:length(tp)]))
    push!(p1, pred_bayes[1,:])
    push!(p2, pred_bayes[2,:])
end
p1 = convert(Array,VectorOfArray(p1))
p2 = convert(Array,VectorOfArray(p2))
p1 = Array{Float64,2}(p1)
p2 = Array{Float64,2}(p2)
;

### Compute mean and 3 std error bar

In [13]:
m1 = vec(mean(p1, 2))
m2 = vec(mean(p2, 2))
e1 = 3*vec(sqrt(var(p1,2)))
e2 = 3*vec(sqrt(var(p2,2)))
;

In [14]:
plot_data()
plot!(tp, m1, ribbon=e1, lab="#prey (prediction)")
plot!(tp, m2, ribbon=e2, lab="#predator (prediction)")
# plot!(tp, m1, ribbon=e1*1e14, lab="#prey (prediction)")
# plot!(tp, m2, ribbon=e2*1e14, lab="#predator (prediction)")

**Note**: the prediction variance seems to be very small; can't even see from the plot

## Also check the trace and density plot of a and $\sigma$

In [15]:
samples_a = chn[:a]
samples_σ = chn[:σ]
;

In [16]:
plot([samples_a, samples_a], layout=@layout([a; b]),
      seriestype=[:line :hist], nbins=100,
      labels=["a (samples)", "a (ground true)"])
plot!([1.5, 1.5], layout=@layout([a; b]),  
       seriestype = [:hline, :vline],
       labels=["a (samples)", "a (ground true)"])

In [17]:
plot([samples_σ, samples_σ], layout=@layout([a; b]),
      seriestype=[:line :hist], nbins=100,
      labels=["sigma (samples)", "sigma (ground true)"])
plot!([0.49, 0.49], layout=@layout([a; b]),  
       seriestype = [:hline, :vline],
       labels=["sigma (samples)", "sigma (ground true)"])