In [32]:
using DifferentialEquations

"""
    phi(a, gamma_a)

Return the saturating ratio `a / (a + gamma_a)`.

Both the sum and the division are broadcast, so `a` and `gamma_a` may each be a
scalar or an array (e.g. an array `a` with a scalar `gamma_a`). For scalar
inputs the result is the plain scalar ratio.
"""
phi(a, gamma_a) = a ./ (a .+ gamma_a)

"""
    f(du, u, p0, t)

In-place right-hand side of the six-state ODE system: the time derivatives of
the states are written into `du`.

# Arguments
- `du`: output vector; entries 1:6 are overwritten with the derivatives of
  `B, E, M, a, h, p` (in that order)
- `u`: current state, unpacked as `B, E, M, a, h, p`
- `p0`: the 20 model parameters, unpacked in the order shown in the body
- `t`: current time; it enters only through the `(cos(t) + 1)^3` term of `du[6]`
"""
function f(du, u, p0, t)
    B, E, M, a, h, p = u
    (betaa, betab, betaE1, betaE2, betah1, betah2, betah3, betaM1, betaM2, betap,
     gammaa, gammaB, gammah, gammap,
     muaE, muaM, muhM, mupB, mupE, q) = p0

    # Saturation terms shared by several equations.
    sat_a = phi(a, gammaa)
    sat_h = phi(h, gammah)
    sat_p = phi(p, gammap)
    free_B = 1 - phi(B, gammaB)

    # States B, E, M: growth proportional to the state itself, minus the q loss.
    du[1] = betab * sat_p * B - q * B
    du[2] = (betaE1 * sat_a + betaE2 * free_B * sat_p) * E - q * E
    du[3] = (betaM1 * sat_a + betaM2 * sat_h) * M - q * M

    # States a, h, p: production terms, minus the q loss, minus consumption terms.
    du[4] = betaa * sat_p * B - q * a - (muaE * E + muaM * M) * sat_a
    du[5] = betah1 * sat_a * E + betah2 * sat_p * B + betah3 * free_B * sat_p * E -
            q * h - muhM * sat_h * M
    du[6] = betap * q * (cos(t) + 1)^3 - q * p - (mupB * B + mupE * E) * sat_p
end

"""
    Ci(i, A, B)

Create a new matrix `C` by replacing the `i`th row of matrix `B` with the `i`th
row of matrix `A`. Because parameters are stored along rows, this swaps in the
samples of parameter `i` from `A` and keeps every other parameter from `B`.

# Arguments
- `i`: index of the row (parameter) of `B` to be replaced
- `A`: matrix containing half of the pseudo-random samples, shape
  (# params, num_samples)
- `B`: matrix containing the other half of the pseudo-random samples, shape
  (# params, num_samples)

# Returns
A copy of `B` whose `i`th row is taken from `A`. Neither `A` nor `B` is modified.
"""
function Ci(i, A, B)
    C = copy(B) # copy first so the caller's B is not overwritten
    C[i, :] = view(A, i, :)
    return C
end


"""
    loss(truth, sol, tpts, max_t)

Compute an averaged absolute relative error between `truth` and the centered
solution `center(sol, tpts, max_t)`.

Entries 1:3 are compared only through their sum, giving one error term; entries
4:6 each contribute their own error term. The four terms are summed and divided
by 4.

# Arguments
- `truth`: vector of ground-truth values, indexed 1:6 like the solution rows
- `sol`: solution object passed on to `center`
- `tpts`: length of the trailing time window passed on to `center`
- `max_t`: final time of the window passed on to `center`

# Returns
The mean of the four absolute relative errors.
"""
function loss(truth, sol, tpts, max_t)
    # Evaluate the centered solution once and reuse it for both parts of the loss.
    mid = center(sol, tpts, max_t)

    # Entries 1:3 are pooled because only their summed value was observed.
    total_truth = sum(truth[1:3])
    s = abs((total_truth - sum(mid[1:3, :])) ./ total_truth)

    # Per-entry relative errors for entries 4:6, averaged together with `s`.
    return (s + sum(broadcast(abs, (truth[4:6] - mid[4:6, :]) ./ truth[4:6]))) / 4
end

"""
    sample(p)

Solve the ODE system `f` once per parameter sample and return, for each sample,
the mean of the solution over its saved time points.

# Arguments
- `p`: matrix of parameter values with shape (# parameters, # samples); column
  `n` is used as the parameter vector of the `n`th ODE problem

# Returns
A vector with one entry per column of `p`; entry `n` is `mean(sol, dims = 2)`
for the solution of the `n`th problem, integrated from `t = 0` to `t = 250`
with `Rodas4()` from the fixed initial state `u0`.

NOTE(review): the mean is taken over the solver's saved steps, which are
presumably not evenly spaced in time, so it is not a time-weighted average —
confirm this is intended.
"""
function sample(p)
    u0 = [0.0004706; 0.0004706; 0.0004706; 9.7079; 7.9551; 32.061]
    max_t = 250.0
    tspan = (0.0, max_t)

    # Build, solve and reduce one problem at a time so that only the per-sample
    # means are kept, not every full solution.
    return map(axes(p, 2)) do n
        prob = ODEProblem(f, u0, tspan, p[:, n])
        sol = solve(prob, Rodas4())
        return mean(sol, dims = 2)
    end
end

"""
    center(sol, tpts, max_t)

Return the midpoint `(maximum + minimum) / 2` of each row of `sol` over the
time window `(max_t - tpts):max_t`.

# Arguments
- `sol`: callable that maps a range of times to an array whose second dimension
  indexes time (e.g. an ODE solution evaluated by interpolation)
- `tpts`: length of the trailing time window
- `max_t`: last time of the window

# Returns
The per-row midpoints, reduced along dimension 2.
"""
function center(sol, tpts, max_t)
    # Evaluate the window once; both extrema are taken from the same values.
    window = sol((max_t - tpts):max_t)
    return (maximum(window, dims = 2) + minimum(window, dims = 2)) / 2
end

center (generic function with 1 method)

In [36]:
using QuasiMonteCarlo
using Statistics


In [37]:
# Parameter ranges used for sampling: one upper limit per model parameter, and
# a lower limit of zero for every parameter.
upper_b = [1e6, 10, 10, 10, 1000, 1e5, 1e5, 10, 10, 1e5,
           1e4, 1e4, 1e4, 1e4, 1e6, 1e6, 1e5, 1e8, 1e8, 10]
lower_b = fill(0, length(upper_b))

n_samples = 5  # alternative value kept from the original: 12

# Draw the two sample matrices A and B, each of shape (# params, num_samples).
# The seed is not fixed, so every call produces a different draw.
sampler = LatticeRuleSample()
A = QuasiMonteCarlo.sample(n_samples, lower_b, upper_b, sampler)
B = QuasiMonteCarlo.sample(n_samples, lower_b, upper_b, sampler)

20×5 Matrix{Float64}:
     4.29496e5       9.29496e5  …       1.79496e5       5.54496e5
     1.03059         6.03059            3.53059         4.78059
     8.01994         3.01994            0.519936        1.76994
     3.54258         8.54258            1.04258         4.79258
   211.45          711.45             961.45          336.45
 22394.1         72394.1        …   47394.1          9894.06
 31455.3         81455.3             6455.27        43955.3
     7.61636         2.61636            0.116359        1.36636
     2.93491         7.93491            0.434912        9.18491
 49716.6         99716.6            74716.6         37216.6
  2456.6          7456.6        …    4956.6          1206.6
   768.12         5768.12            3268.12         9518.12
  7504.36         2504.36               4.36454      6254.36
  9597.57         4597.57            2097.57         3347.57
 64799.0        564799.0           814799.0        189799.0
 39745.0        539745.0        …  289745.0    

In [38]:
using Dates
# Print a wall-clock timestamp before and after a batch of `sample` runs
# (one on A, one on B, then 20 more on B) to gauge the runtime.
println(Dates.format(now(), "HH:MM:SS"))
for X in (A, B)
    sample(X)
end
for _ in 1:20
    sample(B)
end
println(Dates.format(now(), "HH:MM:SS"))

19:02:28
19:02:30


KeyError: KeyError: key "usage_request" not found

KeyError: KeyError: key "usage_request" not found

In [2]:
using JLD2
# Read the previously estimated parameter values from disk.
NM_params = load("/pscratch/sd/m/maadrian/gut_microbiota/params.jld2")["data"]

# Sampling range around the estimates: from half of each estimated value up to
# five times each estimated value.
upper_b = 5 .* NM_params
lower_b = NM_params ./ 2

20-element Vector{Float64}:
 160783.43805659568
      0.6387385039057071
      0.42447189270225255
      0.2616768347915121
    205.99091389998645
  16285.778663361489
   5014.528799758551
      0.37490528008667245
      0.22820391005218382
    532.4383163051334
    119.21133026293779
      5.616199052704344
     66.31328793500653
    205.98386060879025
  24823.981706455463
  19520.284844456735
   1654.8224174516608
 170649.49513407532
      2.6215859166121334e6
      0.027145566895013365

In [3]:
# Display the upper sampling bounds (five times the loaded estimates) for inspection.
upper_b

20-element Vector{Float64}:
      1.6078343805659567e6
      6.387385039057071
      4.244718927022525
      2.616768347915121
   2059.9091389998644
 162857.7866336149
  50145.28799758551
      3.7490528008667248
      2.282039100521838
   5324.383163051334
   1192.1133026293778
     56.16199052704344
    663.1328793500653
   2059.8386060879025
 248239.81706455463
 195202.84844456735
  16548.22417451661
      1.7064949513407531e6
      2.6215859166121334e7
      0.2714556689501336