In [22]:
# Imports
import DifferentialEquations.SciMLBase: AbstractDEProblem
import SciMLSensitivity: ForwardDiffSensitivity
import LinearAlgebra: I, LowerTriangular, logdet
import Random
import CSV
import Lux

include("lib/population.jl");
include("lib/callbacks.jl");
include("lib/objectives.jl");
include("lib/compartment_models.jl");

using Bijectors
using DataFrames
using AbstractGPs
using ApproximateGPs
using DifferentialEquations
;

In [23]:
# Load the warfarin dataset and turn every subject (rows sharing an ID) into an Individual.
df = DataFrame(CSV.File("../data/warfarin.csv"))
df_group = groupby(df, :ID)

"""
    _individual_from_group(group)

Build one `Individual` from the rows of a single subject.

The covariate vector holds the `WEIGHT` of the first row, the observations are the rows
with `DVID == 1` and `MDV == 0`, and the rows with `MDV == 1` (columns `TIME`, `DOSE`,
`RATE`, `DURATION`) are handed to `generate_dosing_callback`.
"""
function _individual_from_group(group)
    covariates = Vector{Float32}(group[1, [:WEIGHT]])
    is_observation = (group.DVID .== 1) .& (group.MDV .== 0)
    observations = group[is_observation, [:TIME, :DV]]
    dosing_events = Matrix{Float32}(group[group.MDV .== 1, [:TIME, :DOSE, :RATE, :DURATION]])
    dosing_callback = generate_dosing_callback(dosing_events)
    return Individual(
        covariates,
        Float32.(observations.TIME),
        Float32.(observations.DV),
        dosing_callback;
        id = group.ID[1],
    )
end

# The element type is kept abstract so `Population` receives a Vector{AbstractIndividual}.
indvs = AbstractIndividual[_individual_from_group(group) for group in df_group]
population = Population(indvs);
println("Done!")

Done!


In [27]:
# Functions
# softplus(x) = log(1 + exp(x)): smooth map from the reals onto the positive reals.
# Only exp of a non-positive number is ever taken, so large `x` returns `x` instead of
# overflowing to Inf as the naive log(exp(x) + 1) does. Both branches are smooth in `x`
# and agree at 0.
softplus(x::T) where {T<:Real} = x > zero(T) ? x + log1p(exp(-x)) : log1p(exp(x));
# Inverse of softplus: log(exp(x) - 1) for x > 0, rewritten as x + log(1 - exp(-x)).
# `expm1` keeps precision for tiny `x` (where exp(x) - 1 cancels to 0) and the rewritten
# form never exponentiates a large positive number, so it does not overflow.
# As before, x == 0 gives -Inf and x < 0 raises a DomainError from `log`.
softplus_inv(x::T) where {T<:Real} = x + log(-expm1(-x));

"""
    build_SVGP(p::NamedTuple; ϵ = 1e-6)

Assemble a `SparseVariationalApproximation` from the constrained parameters `p`.

`p.hyper[1]` multiplies a squared-exponential kernel whose inputs are scaled by
`1 / p.hyper[2]`; `p.hyper[3]` is the constant prior mean. `p.z` are the inducing inputs,
and `p.m` / `p.A` are the mean and covariance of the variational distribution.
`ϵ` is forwarded as the third argument of `LatentGP`.
"""
function build_SVGP(p::NamedTuple; ϵ = 1e-6)
    amplitude = p.hyper[1]
    lengthscale = p.hyper[2]
    prior_mean = ConstMean(p.hyper[3])

    base_kernel = SqExponentialKernel() ∘ ScaleTransform(1 / lengthscale)
    latent = LatentGP(GP(prior_mean, amplitude * base_kernel), GaussianLikelihood(), ϵ)
    variational = MvNormal(p.m, p.A)
    at_inducing = latent(p.z).fx
    return SparseVariationalApproximation(at_inducing, variational)
end

"""
    constrain_(p_)

Map the unconstrained parameter tuple `p_` to the constrained tuple used by `build_SVGP`
and `prior_kl_`.

`sigma` and `hyper` go through `softplus`; `corr` is mapped back to a lower Cholesky
factor by the inverse `VecCholeskyBijector(:L)`. Row `i` of that factor is scaled by
`sigma[i]` to give `L`, and `A = L * L'`. `z` and `m` are passed through unchanged.
"""
function constrain_(p_)
    scales = softplus.(p_.sigma)
    to_cholesky = inverse(Bijectors.VecCholeskyBijector(:L))
    corr_factor = to_cholesky(p_.corr).L
    cov_factor = scales .* corr_factor
    return (
        hyper = softplus.(p_.hyper),
        z = p_.z,
        m = p_.m,
        A = cov_factor * cov_factor',
        L = cov_factor,
    )
end

"""
    _elbo(obj, prob, population, ps_, sigma; N = 100)

Monte Carlo estimate of the evidence lower bound.

`ps_` holds one unconstrained parameter tuple per latent GP and `sigma` the unconstrained
residual-error parameters. Each of the `N` iterations samples the GP outputs at
`population.x[1, :] ./ 200`, applies `softplus`, appends a row of zeros, solves the model
with `forward_adjoint_` and accumulates `ll`. The result is the averaged log-likelihood
minus the sum of `prior_kl_` over all GPs.
"""
function _elbo(obj, prob, population, ps_::AbstractVector{<:NamedTuple}, sigma; N::Int = 100)
    ps = constrain_.(ps_)
    error_params = (error = (sigma = softplus.(sigma),),)
    posteriors = posterior.(build_SVGP.(ps))

    # One column of marginals per GP; transposing puts GPs on rows.
    # presumably population.x[1, :] is body weight (the loader stores WEIGHT) — confirm
    q_f = reduce(
        hcat, marginals.([post(population.x[1, :] ./ 200.f0) for post in posteriors])
    )
    f_μ = mean.(q_f)'
    f_σ = std.(q_f)'

    T = eltype(f_μ)
    n_gps, n_cols = size(f_μ, 1), size(f_μ, 2)
    zero_row = zeros(T, 1, n_cols)

    total = zero(T)
    for _ in 1:N
        # One standard-normal draw per GP, broadcast across every column.
        noise = randn(T, n_gps)
        zs = vcat(softplus.(f_μ + f_σ .* noise), zero_row)
        ŷ = forward_adjoint_(prob, population, zs)
        σ² = variance(obj.error, error_params, ŷ)
        total += ll(Normal, ŷ, σ², population.y)
    end
    return total / N - sum(prior_kl_.(ps))
end

"""
    prior_kl_(p)

Return `(sum(L.^2) + m'm - length(m) - logdet(A)) / 2` for `p.L`, `p.m` and `p.A`.

When `p.A == p.L * p.L'` this is the KL divergence from `N(p.m, p.A)` to a standard
normal of the same dimension, since `sum(L.^2)` is then the trace of `A`.
"""
function prior_kl_(p)
    trace_term = sum(p.L .^ 2)
    mean_term = p.m'p.m
    dim = length(p.m)
    return (trace_term + mean_term - dim - logdet(p.A)) / 2
end

"""
    forward_(problem, individual, zᵢ; get_dv=false, sensealg=nothing, full=false,
             interpolate=false, saveat_=...)

Solve `problem` for one `individual` with parameters `zᵢ` using `Tsit5()`.

# Keywords
- `get_dv`: return only the saved values of state 2 instead of the solution object.
- `sensealg`: forwarded to `solve`.
- `full`: save every state (`save_idxs = 1:length(u0)`); otherwise only state 2.
- `interpolate`: pass an empty `saveat` to the solver and set the callback's
  `save_positions` to 1 for the duration of the solve (reset to 0 afterwards).
- `saveat_`: save times; defaults to `individual.t.y` when `is_timevariable(individual)`
  and to `individual.t` otherwise. Its maximum is the end of the integration span.

The initial state is `individual.initial` unless that is empty, in which case
`problem.u0` is used. The individual's callback is attached and its
`condition.times` are used as `tstops`.
"""
function forward_(
    problem::AbstractDEProblem,
    individual::AbstractIndividual,
    zᵢ::AbstractVecOrMat;
    get_dv::Bool=false,
    sensealg=nothing,
    full::Bool=false,
    interpolate::Bool=false,
    saveat_ = is_timevariable(individual) ? individual.t.y : individual.t,
)
    u0 = isempty(individual.initial) ? problem.u0 : individual.initial
    saveat = interpolate ? empty(saveat_) : saveat_
    save_idxs = full ? (1:length(u0)) : 2
    prob = remake(problem, u0 = u0, tspan = (problem.tspan[1], maximum(saveat_)), p = zᵢ)

    # The callback object is shared with `individual`, so the flag flipped here is
    # restored right after the solve.
    interpolate && (individual.callback.save_positions .= 1)
    sol = solve(
        prob, Tsit5();
        save_idxs = save_idxs, saveat = saveat, callback = individual.callback,
        tstops = individual.callback.condition.times, sensealg = sensealg,
    )
    interpolate && (individual.callback.save_positions .= 0)

    get_dv || return sol
    # With a scalar `save_idxs` each saved state is already the value of state 2, so
    # indexing `sol[2, :]` would run out of bounds; return the saved scalars directly.
    return full ? sol[2, :] : sol.u
end

"""
    forward_adjoint_(problem, population, zs)

Solve `problem` for every member of `population`, pairing them element-wise with the
columns of `zs` as parameters.

Calls `forward_` with `full=true`, `get_dv=true` and
`ForwardDiffSensitivity(; convert_tspan=true)` as the sensitivity algorithm.
"""
function forward_adjoint_(
    problem::AbstractDEProblem, population::Population, zs::AbstractMatrix
)
    sensitivity = ForwardDiffSensitivity(; convert_tspan = true)
    # `(problem,)` keeps the problem fixed while broadcasting over individuals and columns.
    return forward_.(
        (problem,), population, eachcol(zs);
        full = true, get_dv = true, sensealg = sensitivity,
    )
end
;

In [30]:
# Model setup: ODE problem with two states, Float32 time span and an empty parameter vector.
prob = ODEProblem(one_comp_abs!, zeros(Float32, 2), (-0.1f0, 144.f0), Float32[])

# Number of inducing points per GP.
M = 6

# Fresh unconstrained parameter tuple for one sparse GP with `num_inducing` inducing points:
# inducing inputs evenly spaced on [0, 1], zero variational mean, `sigma` and the three
# `hyper` entries all set to softplus_inv(0.1), and `corr` obtained by applying
# VecCholeskyBijector(:L) to the identity matrix.
init_variational_params(num_inducing) = (
    z = collect(range(0, 1, num_inducing)),
    m = zeros(num_inducing),
    sigma = softplus_inv.(fill(0.1, num_inducing)),
    corr = Bijectors.VecCholeskyBijector(:L)(collect(I(num_inducing))),
    hyper = softplus_inv.([0.1, 0.1, 0.1]),
)

# Three independent (non-aliased) copies of the same starting point.
p1 = init_variational_params(M)
p2 = init_variational_params(M)
p3 = init_variational_params(M)



obj = LogLikelihood(Combined())
# One-off evaluation of the objective with three GPs; the value is not stored.
_elbo(obj, prob, population, [p1, p2, p3], softplus_inv.([0.1, 0.1]))

# Parameters that are actually optimised: two GPs plus the unconstrained error sigmas.
ps = [p1, p2]
p_sigma_init = softplus_inv.([0.1, 0.1])

# Separate ADAM state for the GP parameters and for the error parameters.
opt = Optimisers.ADAM(0.1)
opt_state = Optimisers.setup(opt, ps)
opt2 = Optimisers.ADAM(0.1)
opt_state2 = Optimisers.setup(opt2, p_sigma_init)

for epoch in 1:200
    # The loop rebinds top-level variables. Declaring them `global` makes that explicit,
    # so the cell behaves the same in a notebook/REPL and when run as a script (where the
    # assignments would otherwise create loop-local variables and `opt_state` would be
    # read before assignment).
    global ps, p_sigma_init, opt_state, opt_state2

    # Fewer Monte Carlo samples during the first 50 epochs.
    num_samples = epoch <= 50 ? 10 : 50
    loss, back = Zygote.pullback(
        (ps_, sigma_) -> -_elbo(obj, prob, population, ps_, sigma_; N = num_samples),
        ps,
        p_sigma_init,
    )
    # Seed the pullback with a one of the loss's own type.
    grad = back(one(loss))
    println("Epoch $epoch: loss = $loss")
    opt_state, ps = Optimisers.update(opt_state, ps, first(grad))
    opt_state2, p_sigma_init = Optimisers.update(opt_state2, p_sigma_init, last(grad))
end

gps_opt = build_SVGP.(constrain_.(ps))

# 5% and 95% quantiles of every row of `samples`, returned as a 2 × nrows matrix.
sample_quantiles(samples) =
    reduce(hcat, [quantile(row, [0.05, 0.95]) for row in eachrow(samples)])

# Evaluation grid for the GP input and the line/ribbon styling shared by both plots.
x_grid = 0:0.01:1
band_style = (label = nothing, color = :black, linewidth = 1.4, fillalpha = 0.15)

# First GP: 10_000 posterior draws on the grid, mapped through softplus.
f_samples_1 = softplus.(rand(posterior(gps_opt[1])(x_grid, 1e-6), 10_000))
qs1 = sample_quantiles(f_samples_1)
mean1 = vec(mean(f_samples_1; dims = 2))
Plots.plot(x_grid, mean1; ribbon = (mean1 - qs1[1, :], qs1[2, :] - mean1), band_style...)

# Second GP: same summary, with the tick labels multiplied by 200.
f_samples_2 = softplus.(rand(posterior(gps_opt[2])(x_grid, 1e-6), 10_000))
qs2 = sample_quantiles(f_samples_2)
mean2 = vec(mean(f_samples_2; dims = 2))
Plots.plot(
    x_grid, mean2;
    ribbon = (mean2 - qs2[1, :], qs2[2, :] - mean2),
    xticks = (0:0.1:1, (0:0.1:1) .* 200),
    band_style...,
)
;

ErrorException: type Array has no field beta