In [None]:
using DataFrames
using CSV
using Plots
using Plotly
plotly() # select backend for plots

Read the experimental data: time, cell density, dissolved inorganic nitrogen (DIN), and cellular N quota.

In [None]:
# Load the growth-experiment table into a DataFrame (trailing `;` suppresses display).
liefer = DataFrame(CSV.File("liefer-growth-data.csv"));

In [None]:
# gr()
# @df liefer Plots.scatter(:Date, :DIN)

In [None]:
# Plots.scatter(liefer[!, :Date], liefer[!, :DIN])

Assemble the time, DIN, N quota, and cell density data (t, R, Q, X) for the differential equation.
For now, use only one replicate (A) of one species (*Thalassiosira pseudonana*).

In [None]:
# Keep T. pseudonana replicate A, dropping rows that have no cell count.
ss = filter([:"Species", :"Replicate", :"Cell Density"] =>
                ((species, replicate, density) ->
                    species == "Thalassiosira pseudonana" &&
                    replicate == "A" && !ismissing(density)),
            liefer)

In [None]:
# State-variable columns for the selected replicate.
t = ss."Days in N-free Media"   # d
R = ss.DIN                      # µmol/L
Q = ss.N                        # pg/cell
X = ss."Cell Density"           # cells/mL
hcat(t, R, Q, X)

Replace missing values in R with 0s

In [None]:
# Treat missing DIN measurements as 0 and store R as Float64. Mixing the integer
# literal 0 with Float64 values (as a `map` with `? 0 : x` does) can produce an
# abstractly typed Vector{Real}; `float.` keeps the element type concrete.
R = float.(coalesce.(R, 0))

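Fix units: convert R from µmol N/L to pg N/mL so that it is directly comparable with Q·X (pg/cell × cells/mL). Q and X keep their original units.
$$1\ \mu\mathrm{mol\,L^{-1}} \times \frac{1\ \mathrm{L}}{1000\ \mathrm{mL}} \times \frac{14\ \mathrm{g}}{\mathrm{mol}} \times \frac{1\ \mathrm{mol}}{10^{6}\ \mu\mathrm{mol}} \times \frac{10^{12}\ \mathrm{pg}}{\mathrm{g}} = 10^{12} \cdot 10^{-6} \cdot 10^{-3} \cdot 14\ \mathrm{pg\,mL^{-1}} = 1.4 \times 10^{4}\ \mathrm{pg\,mL^{-1}}$$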

In [None]:
# µmol N/L → pg N/mL: 10^12 · 10^-6 · 10^-3 · 14 = 14 000
RpgmL = 14_000 .* R

There is a dilution that happens after the first time step, so the first cell count is scaled by the dilution factor.

In [None]:
# Scale the first cell count by the dilution factor.
# NOTE(review): the factor is read from row 4 of the full `liefer` table, not
# from `ss` — confirm that is the intended row. X is the `ss` column itself, so
# this also updates `ss`; re-running the cell multiplies X[1] again.
X[1] *= liefer[4, :"Dilution Factor"]

Check the mass balance: R + QX should be constant. In fact, about 15–20% of the mass appears to be lost between time 0 and time 1, so reduce R at time 0 until the total mass at time 0 equals the total mass at time 1.

In [None]:
# Total N (medium + cells), medium N, and cellular N over time.
let cellN = Q .* X
    Plots.scatter(t, RpgmL .+ cellN)
    Plots.scatter!(t, RpgmL)
    Plots.scatter!(t, cellN)
end

In [None]:
# Choose R at the first time so that total N there equals total N at the second time.
let total_at_t2 = RpgmL[2] + Q[2] * X[2]
    RpgmL[1] = total_at_t2 - Q[1] * X[1]
end

Solve Droop differential equation


In [None]:
using DifferentialEquations

In [None]:
# Droop (cell-quota) model: in-place right-hand side for DifferentialEquations.
#   u = [R, Q, X]  nutrient in the medium, N quota per cell, cell density
#   p = [Km, Vmax, Qmin, muMax]
# Per-cell uptake is Vmax*R/(Km + R); growth rate is muMax*(1 - Qmin/Q).
# With the dilution rate d fixed at 0 there is no inflow or washout, so
# R + Q*X is conserved.
function droop!(du, u, p, t)
    R, Q, X = u
    Km, Vmax, Qmin, muMax = p
    d, R0 = 0.0, 0.0                 # dilution rate and inflow nutrient (both off)
    uptake = Vmax * R / (Km + R)
    growth = muMax * (1 - Qmin / Q)
    du[1] = d * (R0 - R) - uptake * X
    du[2] = uptake - growth * Q
    du[3] = (growth - d) * X
end

# Toy run with arbitrary values, just to check that droop! integrates.
p = [0.1, 2.0, 1.0, 0.8]      # Km, Vmax, Qmin, muMax
u0 = [1.0, 1.0, 1.0]          # R, Q, X
tspan = (0.0, 10.0)           # also reused by the fitting cells further down
tsteps = 0.0:0.1:10.0         # output grid; defined but not passed to solve

prob = ODEProblem(droop!, u0, tspan, p)
sol = solve(prob, Tsit5())


In [None]:
plotly()          # (re)select the Plotly backend for Plots
Plots.plot(sol)   # trajectories of R, Q, X from the toy run

Find the parameters that best match the data (t, R, Q, X). Should R be ignored, or should missing R values be treated as 0?

Compare model and data at the observation times t, and leave R out of the loss.

In [None]:
using DifferentialEquations, Flux, DiffEqFlux, Optim, DiffEqSensitivity

In [None]:
# Initial guess for [Km, Vmax, Qmin, muMax].
p = [100.0, 0.1, 1.0, 0.8]

# Sample mean and sample (n - 1) standard deviation, used to scale residuals.
my_mean = v -> sum(v) / length(v)
my_sd = v -> sqrt(sum((v .- my_mean(v)) .^ 2) / (length(v) - 1))

# Start the model from the first observation of each state variable.
prob = ODEProblem(droop!, [RpgmL[1], Q[1], X[1]], tspan, p)

# Misfit for parameters p = [Km, Vmax, Qmin, muMax], using the global `prob`.
# Residuals in X and Q at every observation time are each scaled by the data's
# standard deviation. (A log-scale X misfit was also tried.)
# Returns (misfit, sol) so the optimizer callback can see the solution.
function loss(p)
    sol = solve(prob, Tsit5(), p = p, saveat = t)
    nobs = length(t)
    deltaX = [sol.u[i][3] - X[i] for i in 1:nobs]
    deltaQ = [sol.u[i][2] - Q[i] for i in 1:nobs]
    misfit = sum(abs2, deltaX ./ my_sd(X)) + sum(abs2, deltaQ ./ my_sd(Q))
    return misfit, sol
end

# Optimizer callback for sciml_train: show the current loss. Returning `true`
# would stop the optimization early, so always return `false`.
callback = (p, l, pred) -> begin
    display(l)
    false
end


Test loss(p) before passing on to optimizer

In [None]:
loss(p)  # sanity check: evaluate (misfit, sol) once before optimizing

In [None]:
# Fit [Km, Vmax, Qmin, muMax] with ADAM (learning rate 0.1) for 100 iterations.
result_ode = DiffEqFlux.sciml_train(loss, p, ADAM(0.1);
                                    cb = callback, maxiters = 100)

In [None]:
# Re-solve at the observation times with the optimized parameters.
sol = solve(prob, Tsit5(); p = result_ode.u, saveat = t)

In [None]:
# Fitted trajectories (lines) against the data (points) for R, Q and log X.
pR = Plots.plot(sol; vars = [(0, 1)], ylabel = "R", xlabel = "")
Plots.scatter!(pR, t, RpgmL)
pQ = Plots.plot(sol; vars = [(0, 2)], ylabel = "Q", xlabel = "")
Plots.scatter!(pQ, t, Q)
pX = Plots.plot(sol; vars = [((tt, x) -> (tt, log.(x)), 0, 3)], ylabel = "log X")
Plots.scatter!(pX, t, log.(X))
Plots.plot(pR, pQ, pX; layout = (3, 1))

Why is this fit so bad? Perhaps one reason is that the initial conditions are not well known.

Include them among the parameters to be estimated. (A penalty for initial conditions that stray far from the data is not yet part of the loss below.)

In [None]:
# Fit vector: [R(0), Q(0), X(0), Km, Vmax, Qmin, muMax].
p = [RpgmL[1], Q[1], X[1],      # initial-state guesses from the data
     100.0, 1.0, 1.0, 0.8]      # Km, Vmax, Qmin, muMax

# Misfit when the initial state is fitted too: p = [R(0), Q(0), X(0), Km, Vmax,
# Qmin, muMax]. Uses the stiff solver Rosenbrock23 (AutoTsit5(Rosenbrock23())
# and Tsit5() were also tried). It fits Q and log X only; including R made
# results much worse. The last observation is left out of the misfit.
# Returns (misfit, sol).
function loss(p)
    u0, θ = p[1:3], p[4:end]
    sol = solve(ODEProblem(droop!, u0, tspan, θ), Rosenbrock23(), p = θ, saveat = t)
    nfit = length(t) - 1
    deltaX = [log(sol.u[i][3]) - log(X[i]) for i in 1:nfit]
    deltaQ = [sol.u[i][2] - Q[i] for i in 1:nfit]
    misfit = sum(abs2, deltaX ./ my_sd(log.(X))) + sum(abs2, deltaQ ./ my_sd(Q))
    return misfit, sol
end

In [None]:
loss(p)  # sanity check: evaluate (misfit, sol) once before optimizing

In [None]:
# Fit initial state and Droop parameters together: ADAM(0.1), 100 iterations.
result_ode = DiffEqFlux.sciml_train(loss, p, ADAM(0.1);
                                    cb = callback, maxiters = 100)

In [None]:
# Fitted initial state [R, Q, X](0) (first) next to the starting values taken
# from the data (second; R after the mass-balance adjustment).
[result_ode.u[1:3], [RpgmL[1], Q[1], X[1]]]

In [None]:
# Rebuild the problem from the fitted vector: entries 1:3 are the initial state,
# the rest are the model parameters. No `saveat`, so the solution is dense for
# smooth plotting.
prob = ODEProblem(droop!, result_ode.u[1:3], tspan, result_ode.u[4:end])
sol = solve(prob, Rosenbrock23(); p = result_ode.u[4:end])

In [None]:
# Fitted trajectories (lines) vs data (points) for R, Q, log X and total N.
pR = Plots.plot(sol; vars = [(0, 1)], ylabel = "R", xlabel = "")
Plots.scatter!(pR, t, RpgmL)
pQ = Plots.plot(sol; vars = [(0, 2)], ylabel = "Q", xlabel = "")
Plots.scatter!(pQ, t, Q)
pX = Plots.plot(sol; vars = [((tt, x) -> (tt, log.(x)), 0, 3)], ylabel = "log X")
Plots.scatter!(pX, t, log.(X))
pM = Plots.plot(sol; vars = [((tt, r, q, x) -> (tt, r + q * x), 0, 1, 2, 3)],
                ylabel = "total Mass")
Plots.scatter!(pM, t, RpgmL .+ Q .* X)
Plots.plot(pR, pQ, pX, pM; layout = (2, 2))

Changing the integrator to a stiff method and not fitting to R helps.

Gather all three replicates together. Estimate the time-0 values separately for each replicate, but use the same Qmin, Km, Vmax, and muMax for all of them.

In [None]:
# Keep replicate A (prepared in the cells above) under numbered names.
t1, RpgmL1, Q1, X1 = t, RpgmL, Q, X

# Prepare one replicate the same way replicate A was prepared above:
#   - select its rows (dropping rows without a cell count);
#   - convert DIN to pg N/mL, treating missing DIN as 0 (stored as Float64);
#   - scale the first cell count by the dilution factor;
#   - build a mass-balance plot;
#   - set R at the first time so total N there equals total N at the second time.
# Returns (ss, t, R, Q, X, RpgmL, massplot).
function prepare_replicate(replicate; species = "Thalassiosira pseudonana", data = liefer)
    ss = filter([:"Species", :"Replicate", :"Cell Density"] =>
                (x, y, z) -> x == species && y == replicate && !ismissing(z), data)
    t = ss."Days in N-free Media"
    R = float.(coalesce.(ss.DIN, 0))
    Q = ss.N
    X = ss."Cell Density"            # the ss column itself, so X[1] below updates ss too
    RpgmL = R .* (10^3 * 14)         # µmol/L → pg/mL
    # NOTE(review): row 4 of the full table supplies the dilution factor for
    # every replicate — confirm this is the intended row.
    X[1] = X[1] * data[4, :"Dilution Factor"]
    massplot = Plots.scatter(t, RpgmL .+ Q .* X)
    Plots.scatter!(massplot, t, RpgmL)
    Plots.scatter!(massplot, t, Q .* X)
    RpgmL[1] = RpgmL[2] + Q[2] * X[2] - Q[1] * X[1]
    return ss, t, R, Q, X, RpgmL, massplot
end

# As before, the unnumbered globals (ss, t, R, Q, X, RpgmL) end up holding replicate C.
ss, t, R, Q, X, RpgmL, p2 = prepare_replicate("B")
t2, RpgmL2, Q2, X2 = t, RpgmL, Q, X
ss, t, R, Q, X, RpgmL, p3 = prepare_replicate("C")
t3, RpgmL3, Q3, X3 = t, RpgmL, Q, X

Plots.plot(p2, p3)

In [None]:
# Fit vector layout: [R, Q, X](0) for replicate A (1:3), B (4:6) and C (7:9),
# then the shared Droop parameters [Km, Vmax, Qmin, muMax] (10:13).
# (Entry 4 previously reused replicate A's R(0) by mistake.)
p = [RpgmL1[1], Q1[1], X1[1],
     RpgmL2[1], Q2[1], X2[1],
     RpgmL3[1], Q3[1], X3[1],
     100.0, 1.0, 1.0, 0.8]

# Misfit of one replicate: solve the Droop model from initial state `u0` with
# parameters `θ` = [Km, Vmax, Qmin, muMax] at that replicate's observation times.
# Score log X and Q, each scaled by the data's standard deviation.
# The last observation is left out, as in the single-replicate loss.
# Returns (misfit, sol).
function replicate_loss(u0, θ, tobs, Qobs, Xobs)
    prob = ODEProblem(droop!, u0, tspan, θ)
    sol = solve(prob, Rosenbrock23(), p = θ, saveat = tobs)
    nfit = length(tobs) - 1     # each replicate uses its OWN number of observations
    deltaQ = [sol.u[i][2] - Qobs[i] for i in 1:nfit]
    deltaX = [log(sol.u[i][3]) - log(Xobs[i]) for i in 1:nfit]
    misfit = sum(abs2, deltaX ./ my_sd(log.(Xobs))) + sum(abs2, deltaQ ./ my_sd(Qobs))
    return misfit, sol
end

# Joint misfit over replicates A, B, C.
# p = [u0_A; u0_B; u0_C; Km, Vmax, Qmin, muMax], with the last four shared.
# Returns (misfit, (solA, solB, solC)); the tuple of this fit's solutions is
# handed to the optimizer callback.
function loss(p)
    θ = p[10:end]
    l1, sol1 = replicate_loss(p[1:3], θ, t1, Q1, X1)
    l2, sol2 = replicate_loss(p[4:6], θ, t2, Q2, X2)
    l3, sol3 = replicate_loss(p[7:9], θ, t3, Q3, X3)
    return l1 + l2 + l3, (sol1, sol2, sol3)
end

In [None]:
loss(p)  # sanity check: evaluate the joint loss once before optimizing

In [None]:
# Joint fit of three initial states plus shared parameters: ADAM(0.1), 100 iterations.
result_ode = DiffEqFlux.sciml_train(loss, p, ADAM(0.1);
                                    cb = callback, maxiters = 100)

In [None]:
# Plot the three-replicate fit. Each replicate is re-solved from its own fitted
# initial state (result_ode.u[1:3], [4:6], [7:9]) with the shared fitted
# parameters (result_ode.u[10:end]). The global `sol` still holds the earlier
# single-replicate fit, so it does not describe this model.
fitted_sols = [solve(ODEProblem(droop!, result_ode.u[(3k - 2):(3k)], tspan,
                                result_ode.u[10:end]),
                     Rosenbrock23(); p = result_ode.u[10:end]) for k in 1:3]
replicate_data = [(t1, RpgmL1, Q1, X1), (t2, RpgmL2, Q2, X2), (t3, RpgmL3, Q3, X3)]

pR = Plots.plot(ylabel = "R", xlabel = "")
pQ = Plots.plot(ylabel = "Q", xlabel = "")
pX = Plots.plot(ylabel = "log X")
pM = Plots.plot(ylabel = "total Mass")
for (fsol, (tk, Rk, Qk, Xk)) in zip(fitted_sols, replicate_data)
    Plots.plot!(pR, fsol, vars = [(0, 1)])
    Plots.scatter!(pR, tk, Rk)
    Plots.plot!(pQ, fsol, vars = [(0, 2)])
    Plots.scatter!(pQ, tk, Qk)
    Plots.plot!(pX, fsol, vars = [((tt, x) -> (tt, log(x)), 0, 3)])
    Plots.scatter!(pX, tk, log.(Xk))
    Plots.plot!(pM, fsol, vars = [((tt, r, q, x) -> (tt, r + q * x), 0, 1, 2, 3)])
    Plots.scatter!(pM, tk, Rk .+ Qk .* Xk)
end
Plots.plot(pR, pQ, pX, pM, layout = (2, 2))

Try Bayesian fitting (Turing.jl) to obtain posterior distributions over the parameters and the model solutions.

https://github.com/TuringLang/TuringTutorials/blob/master/10_diffeq.ipynb
https://turing.ml/dev/tutorials/10-bayesiandiffeq/

In [None]:
using Turing, Distributions, DifferentialEquations 

# Import MCMCChain, Plots, and StatsPlots for visualizations and diagnostics.
using MCMCChains, Plots, StatsPlots

# Set a seed for reproducibility.
using Random
Random.seed!(14);
using Logging
Logging.disable_logging(Logging.Warn)

In [None]:
Turing.setadbackend(:forwarddiff)

# Bayesian Droop model for one replicate.
#   t    : observation times
#   R, X : data; only their first values are used, as the ODE initial state
#   Q    : observed N quota
#   logX : observed log cell density (pass log.(X))
# Likelihood: Normal noise on Q (σ1) and on log X (σ2).
@model function fitDroop(t, R, Q, X, logX)
    σ1 ~ InverseGamma(2, 3)  # observation noise on Q
    σ2 ~ InverseGamma(2, 3)  # observation noise on log X
    Km ~ truncated(Normal(100, 10), 0, 200)
    Vmax ~ truncated(Normal(1.2, 0.5), 0, 3)
    Qmin ~ truncated(Normal(1.0, 0.5), 0, 3)
    muMax ~ truncated(Normal(1.0, 0.5), 0, 3)

    p = [Km, Vmax, Qmin, muMax]  # order expected by droop!

    # Build the problem with plain Float64 values, then swap in the sampled
    # parameters (possibly AD dual numbers) with remake. The initial state comes
    # from this replicate's own data rather than from globals.
    prob1 = ODEProblem(droop!, [R[1], Q[1], X[1]], (0.0, 10.0), [200.0, 1.0, 1.0, 1.0])
    prob = remake(prob1, p = p)

    predicted = solve(prob, Rosenbrock23(), saveat = t)

    # NOTE(review): only the first 7 observations enter the likelihood —
    # confirm this matches the intended data window.
    nobs = 7
    # A solver that stops early saves fewer than nobs states. Reject such a draw
    # instead of raising a BoundsError inside the sampler.
    if length(predicted) < nobs
        Turing.@addlogprob! -Inf
        return nothing
    end

    for j in 1:nobs
        Q[j] ~ Normal(predicted[j][2], σ1)
        logX[j] ~ Normal(log(predicted[j][3]), σ2)
    end
end


# Variant of fitDroop that also samples the initial state (R0, Q0, X0) and fits
# only Q; the log X likelihood is commented out.
# NOTE(review): this model is marked as failing (at the remake below and at the
# call site); the cause is not diagnosed here. The R0 and X0 priors are centered
# on fixed numbers (300000, 65000), not on the data passed in.
@model function fitDroop1(t, R, Q, X, logX)
    σ1 ~ InverseGamma(2, 3) # ~ is the tilde character
    # σ2 ~ InverseGamma(2, 3) 
    R0 ~ Normal(300000, 1000)
    Q0 ~ truncated(Normal(3, 1), 0, 10)
    X0 ~ Normal(65000,1000)
    Km ~ truncated(Normal(100,10),0,200)
    Vmax ~ truncated(Normal(1.2,0.5),0,3)
    Qmin ~ truncated(Normal(1.0,0.5),0,3)
    muMax ~ truncated(Normal(1.0,0.5),0,3)

    # Parameter order expected by droop!.
    p = [ Km, Vmax, Qmin, muMax]

    # Build the problem with numeric placeholder values (replicate A's data for u0),
    # then replace both u0 and p with the sampled values.
    prob1 = ODEProblem(droop!, [RpgmL1[1], Q1[1], X1[1]], (0.0, 10.0), [200.0, 1.0, 1.0, 1.0])
    prob = remake(prob1, u0=[R0, Q0, X0], p=p)  # reported to fail here ****

    # prob = ODEProblem(droop!, [R0, Q0, X0], (0,10), p)
    # prob = ODEProblem(droop!, [R[1], Q[1], exp(X[1])], (0.0, 10.0), p)
    predicted = solve(prob, Rosenbrock23(), saveat=t)
    
    # Only the first 7 observations of Q enter the likelihood.
    for j = 1:7
        Q[j] ~ Normal(predicted[j][2], σ1)
        # If re-enabled, this should index with j and compare against
        # log(predicted[j][3]), as fitDroop does:
        # logX[i] ~ Normal(predicted[i][3], σ2)
    end
end

In [None]:
# reminder of how this works...
prob = ODEProblem(droop!, [RpgmL1[1], Q1[1], X1[1]], (0.0, 10.0), [200.0, 1.0, 1.0, 1.0])
predicted = solve(prob, Rosenbrock23(), saveat=t)
predicted[7][2]

In [None]:
# Single-replicate model for replicate A, using A's own observation times t1
# (the global `t` holds replicate C's times at this point).
model = fitDroop(t1, RpgmL1, Q1, X1, log.(X1))
# model = fitDroop1(t1, RpgmL1, Q1, X1, log.(X1)) # fails ***
# chain = sample(model, NUTS(0.65), 100)  # 8:20
# Runs 3 independent chains sequentially (no multithreading); takes a few minutes.
# chain = mapreduce(c -> sample(model, NUTS(.65), 1000), chainscat, 1:3) # 7:43

In [None]:
# Four NUTS chains via MCMCThreads; start Julia with several threads
# (e.g. `julia --threads 4`, check with Threads.nthreads()) to run them in parallel.
# 100 draws per chain is a demo only — too few for inference; 1000 per chain was
# guessed to take roughly 15 minutes.
chain2 = sample(model, NUTS(0.65), MCMCThreads(), 100, 4; progress = false)


In [None]:
# Posterior medians, in the order (muMax, Qmin, Km, Vmax).
median.((chain2[:muMax], chain2[:Qmin], chain2[:Km], chain2[:Vmax]))

In [None]:
Plots.plot(chain2)  # diagnostic plots of the sampled parameters (MCMCChains/StatsPlots recipe)

In [None]:
# Posterior draws as a (draws × 4) matrix with columns [Km, Vmax, Qmin, muMax],
# the order droop! unpacks from `p`. Columns are selected by name because
# Array(chain2) orders parameters the way the chain stores them, not droop!'s
# order. The same linear index in each column refers to the same draw.
chain_array = hcat(vec(chain2[:Km]), vec(chain2[:Vmax]),
                   vec(chain2[:Qmin]), vec(chain2[:muMax]));


In [None]:
# Posterior-median solution. The parameter order must match droop!:
# [Km, Vmax, Qmin, muMax] (previously Qmin and Vmax were swapped).
sol2 = solve(remake(prob;
                    p = [median(chain2[:Km]), median(chain2[:Vmax]),
                         median(chain2[:Qmin]), median(chain2[:muMax])]),
             Rosenbrock23());

In [None]:
# Replicate A's R data (at A's own times t1) with 300 posterior trajectories in
# grey and the posterior-median trajectory `sol2` in red.
pl = Plots.scatter(t1, RpgmL1);
for k in 1:300
    # A random posterior draw. Columns 1:4 of chain_array go straight to droop!,
    # so they must be ordered [Km, Vmax, Qmin, muMax].
    resol = solve(remake(prob, p = chain_array[rand(1:size(chain_array, 1)), 1:4]),
                  Rosenbrock23())
    plot!(resol, vars = (0, 1), alpha = 0.3, color = "#BBBBBB", legend = false,
          ylims = (0, Inf))
end
plot!(sol2, vars = (0, 1), alpha = 1, color = "#BB0000", legend = false, ylims = (0, Inf))
display(pl)


In [None]:
# Replicate A's Q data (at A's own times t1) with 300 posterior trajectories in
# grey and the posterior-median trajectory `sol2` in red.
pl = Plots.scatter(t1, Q1);
for k in 1:300
    # A random posterior draw. Columns 1:4 of chain_array go straight to droop!,
    # so they must be ordered [Km, Vmax, Qmin, muMax].
    resol = solve(remake(prob, p = chain_array[rand(1:size(chain_array, 1)), 1:4]),
                  Rosenbrock23())
    plot!(resol, vars = (0, 2), alpha = 0.31, color = "#BBBBBB", legend = false)
end
plot!(sol2, vars = (0, 2), alpha = 1, color = "#BB0000", legend = false)
display(pl)


In [None]:
# Replicate A's log cell density (at A's own times t1) with 300 posterior
# trajectories in grey and the posterior-median trajectory `sol2` in red.
pl = Plots.scatter(t1, log.(X1));
for k in 1:300
    # A random posterior draw. Columns 1:4 of chain_array go straight to droop!,
    # so they must be ordered [Km, Vmax, Qmin, muMax].
    resol = solve(remake(prob, p = chain_array[rand(1:size(chain_array, 1)), 1:4]),
                  Rosenbrock23())
    plot!(resol, vars = ((tt, x) -> (tt, log.(x)), 0, 3), alpha = 0.3,
          color = "#BBBBBB", legend = false)
end
plot!(sol2, vars = ((tt, x) -> (tt, log.(x)), 0, 3), alpha = 1, color = "#BB0000",
      legend = false)
display(pl)

TODO: revise the Bayesian fit to also estimate the initial conditions and to use all three replicates.