
Help fitting parameters of a simple SIR model through ODE simulation #2559

@TorkelE


I have an SIR model that I want to fit using Turing. However, I am running into problems with the parameters: it seems that I have to provide a good initial guess, but it would be useful to know what to do when I do not know the true values.

First, I generate some synthetic data to fit to (I also have a case with real data, but I am trying to keep this a minimal working example).

# Fetch packages
using Distributions
using Turing
using OrdinaryDiffEq
using StatsPlots
using Random
Random.seed!(14);

# Define SIR model.
function sir(du, u, p, t)
    S, I, R = u
    γ, ν = p
    infection = γ * S * I
    recovery = ν * I
    du[1] = -infection
    du[2] = infection - recovery
    du[3] = recovery
    return nothing
end

# Create ODEProblem for true model parameters.
u0_known = [999.0, 1.0, 0.0]  # Floats: the ODE state must be able to hold non-integer values.
tend_known = 100.0
p_true = 0.0003, 0.1
oprob_true = ODEProblem(sir, u0_known, tend_known, p_true)

# Generate synthetic data (and plot it).
sol_true = solve(oprob_true)
measured_t = 5:5:100
measured_I = [rand(Normal(I, 0.2*I)) for I in sol_true(measured_t, idxs = 2).u]
plot(sol_true; label = "Vals (true)")
plot!(measured_t, measured_I; label = "Vals (measured)", seriestype = :scatter)

[Image: true S, I, R trajectories, with the measured I values shown as scatter points]
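As a quick sanity check on `sir`: every infection moves one individual from S to I and every recovery moves one from I to R, so the three derivatives should sum to zero and the total population (1000 here) stays constant. The snippet below repeats the same right-hand side so it runs on its own and evaluates it once at the initial state with the parameters used above.

```julia
# Same right-hand side as `sir` above, repeated so this snippet is self-contained.
function sir(du, u, p, t)
    S, I, R = u
    γ, ν = p
    infection = γ * S * I
    recovery = ν * I
    du[1] = -infection
    du[2] = infection - recovery
    du[3] = recovery
    return nothing
end

# Evaluate the derivatives at the initial state with the "true" parameters.
du = zeros(3)
sir(du, [999.0, 1.0, 0.0], (0.0003, 0.1), 0.0)

# Infections and recoveries only move people between compartments,
# so the derivatives should cancel out (up to floating-point rounding).
println(du)
println(sum(du))
```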

Next, I create my Turing model (primarily based on https://turinglang.org/docs/tutorials/bayesian-differential-equations/).

# Creates Turing model.
@model function fit_sir(data, prob)
    # Sets prior distributions.
    γ ~ LogUniform(0.00001, 0.001)
    ν ~ LogUniform(0.01, 0.9)
    σI ~ LogUniform(0.1, 1)

    # Simulate the model at the measurement times (`measured_t` is read from the global scope).
    prob = remake(prob; p = [γ, ν])
    predicted = solve(prob; saveat = measured_t, verbose = false, maxiters = 10000)

    # If simulation was unsuccessful, the likelihood is -Inf.
    if !SciMLBase.successful_retcode(predicted)
        Turing.@addlogprob! -Inf
        return nothing
    end

    # Computes the likelihood of the observations (one per time point in `measured_t`).
    for i in eachindex(predicted)
        I = max(predicted[i][2], 0.0)
        data[i] ~ truncated(Normal(I, σI * I); lower = 0.0)
    end

    return nothing
end
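The likelihood loop pairs `data[i]` with the solution saved at `measured_t[i]` (through `saveat`), so `data` has to hold exactly one observation per measurement time. Passing anything else, such as the full solution `sol_true[2, :]` saved at the solver's own adaptive time steps, pairs observations with the wrong time points. A minimal guard that could be called before building the model; `check_alignment` is a hypothetical helper, not part of Turing or DifferentialEquations:

```julia
# Hypothetical helper, not part of Turing or DifferentialEquations:
# fail early if the observations do not line up with the `saveat` times.
function check_alignment(data::AbstractVector, ts::AbstractVector)
    if length(data) != length(ts)
        throw(ArgumentError("expected $(length(ts)) observations (one per time point), got $(length(data))"))
    end
    return true
end

ts = 5:5:100                          # same grid as `measured_t` in the issue
check_alignment(fill(1.0, 20), ts)    # 20 observations for 20 times: returns true
```

In the setting of this issue, the call would be `check_alignment(measured_I, measured_t)`.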

Fitting model without an initial guess

This simply gives me an error, and I am a bit uncertain how to proceed from here:

# Fit model to data (fails without an initial guess).
# The observations are `measured_I`, recorded at `measured_t` (the times used in `saveat`).
model = fit_sir(measured_I, oprob_true)
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress = false)
┌ Warning: failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword
└ @ Turing.Inference ~/.julia/packages/Turing/ocrZY/src/mcmc/hmc.jl:182
ERROR: failed to find valid initial parameters in 1000 tries. This may indicate an error with the model or AD backend; please open an issue at https://github.com/TuringLang/Turing.jl/issues
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] initialstep(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}, vi_original::DynamicPPL.VarInfo{…}; initial_params::Nothing, nadapts::Int64, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/ocrZY/src/mcmc/hmc.jl:185
  [3] step(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, spl::DynamicPPL.Sampler{…}; initial_params::Nothing, kwargs::@Kwargs{…})
    @ DynamicPPL ~/.julia/packages/DynamicPPL/RoIdx/src/sampler.jl:125
  [4] step
    @ ~/.julia/packages/DynamicPPL/RoIdx/src/sampler.jl:108 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/AbstractMCMC/64ptM/src/sample.jl:161 [inlined]
  [6] macro expansion
    @ ~/.julia/packages/AbstractMCMC/64ptM/src/logging.jl:16 [inlined]
  [7] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{…})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/64ptM/src/sample.jl:144
  [8] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/64ptM/src/sample.jl:109 [inlined]
  [9] #sample#36
    @ ~/.julia/packages/Turing/ocrZY/src/mcmc/hmc.jl:113 [inlined]
 [10] sample
    @ ~/.julia/packages/Turing/ocrZY/src/mcmc/hmc.jl:82 [inlined]
 [11] sample_chain
    @ ~/.julia/packages/AbstractMCMC/64ptM/src/sample.jl:633 [inlined]
 [12] #4
    @ ./generator.jl:37 [inlined]
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{…}}, Base.var"#4#5"{AbstractMCMC.var"#sample_chain#82"{…}}})
    @ Base ./array.jl:791
 [15] map
    @ ./abstractarray.jl:3495 [inlined]
 [16] mcmcsample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, ::MCMCSerial, N::Int64, nchains::Int64; progressname::String, initial_params::Nothing, initial_state::Nothing, kwargs::@Kwargs{…})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/64ptM/src/sample.jl:645
 [17] sample(rng::Random.TaskLocalRNG, model::DynamicPPL.Model{…}, sampler::DynamicPPL.Sampler{…}, ensemble::MCMCSerial, N::Int64, n_chains::Int64; chain_type::Type, progress::Bool, kwargs::@Kwargs{})
    @ Turing.Inference ~/.julia/packages/Turing/ocrZY/src/mcmc/Inference.jl:318
 [18] sample
    @ ~/.julia/packages/Turing/ocrZY/src/mcmc/Inference.jl:307 [inlined]
 [19] #sample#10
    @ ~/.julia/packages/Turing/ocrZY/src/mcmc/Inference.jl:304 [inlined]
 [20] sample
    @ ~/.julia/packages/Turing/ocrZY/src/mcmc/Inference.jl:293 [inlined]
 [21] #sample#9
    @ ~/.julia/packages/Turing/ocrZY/src/mcmc/Inference.jl:288 [inlined]
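The warning above points to the `initial_params` keyword. When the true values are unknown, one option (a sketch, not an official Turing recipe) is to give each chain its own starting point drawn log-uniformly from the same bounds as the priors, so every start at least lies inside the prior support. `loguniform_draw` and `draw_initial` are hypothetical helpers that use only Julia's `Random` standard library; they return one `[γ, ν, σI]` vector per chain, in the same layout as the `initial_params` used in the runs of this issue. This does not guarantee success: if large parts of the prior region give trajectories where I is essentially zero, the likelihood (whose standard deviation is `σI * I`) can become degenerate there, and a random start may still land in such a region.

```julia
using Random

# Hypothetical helper: draw a value log-uniformly between `lo` and `hi`,
# matching the support of the `LogUniform(lo, hi)` priors in `fit_sir`.
loguniform_draw(rng, lo, hi) = exp(log(lo) + rand(rng) * (log(hi) - log(lo)))

# One starting vector [γ, ν, σI] per chain, each drawn independently
# (unlike `fill(v, n)`, which puts the same vector object in every slot).
function draw_initial(rng::AbstractRNG, nchains::Integer)
    return [[loguniform_draw(rng, 1e-5, 1e-3),   # γ
             loguniform_draw(rng, 0.01, 0.9),    # ν
             loguniform_draw(rng, 0.1, 1.0)]     # σI
            for _ in 1:nchains]
end

init = draw_initial(MersenneTwister(1), 3)
# It could then be passed as in the runs of this issue, e.g.:
# chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress = false, initial_params = init)
```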

Fitting model with a perfect initial guess

This works fine (although it is not very interesting).

# Tries with initial parameters (that are the correct ones). Works fine.
initial_params = fill([0.0003, 0.1, 0.2], 3)
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress = false, initial_params)
┌ Info: Found initial step size
└   ϵ = 7.888609052210118e-32
┌ Info: Found initial step size
└   ϵ = 7.888609052210118e-32
┌ Info: Found initial step size
└   ϵ = 7.888609052210118e-32

Fitting model with a slightly wrong initial guess

This runs; however, the plots look very weird: the traces are entirely flat and not at the correct values.

# Tries with slightly wrong initial guesses. Plots look weird.
initial_params = fill([0.0005, 0.05, 0.35], 3)
chain = sample(model, NUTS(), MCMCSerial(), 1000, 3; progress = false, initial_params)
plot(chain)
┌ Info: Found initial step size
└   ϵ = 7.888609052210118e-32
┌ Info: Found initial step size
└   ϵ = 7.888609052210118e-32
┌ Info: Found initial step size
└   ϵ = 7.888609052210118e-32

[Image: trace plots from `plot(chain)`]
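Every run above reports the same initial step size, ϵ = 7.888609052210118e-32. That value equals 0.1 × 2^-100, which suggests (an inference, not something the log states) that the step-size search started from 0.1 and halved it around 100 times without finding an acceptable step. With a step that small, NUTS would barely move from its starting point, which would be consistent with traces that stay flat at the initial values: correct-looking when started at the true parameters, and wrong when started elsewhere. The arithmetic can be checked directly:

```julia
# Step size printed by Turing in every run of the issue.
step_reported = 7.888609052210118e-32

# Candidate explanation: a starting step of 0.1 halved 100 times.
step_halved = 0.1 * 2.0^-100

println(step_halved)
println(isapprox(step_reported, step_halved; rtol = 1e-12))
```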
