Skip to content

Commit

Permalink
Merge pull request #216 from SciML/compathelper/new_version/2021-06-0…
Browse files Browse the repository at this point in the history
…3-02-42-43-659-1397580942

CompatHelper: bump compat for "Turing" to "0.16"
  • Loading branch information
Vaibhavdixit02 committed Jul 12, 2021
2 parents 1b95523 + 9ed7f80 commit ea7fa7c
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 127 deletions.
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ authors = ["Vaibhavdixit02 <vaibhavyashdixit@gmail.com>"]
version = "2.25.0"

[deps]
ApproxBayes = "f5f396d3-230c-5e07-80e6-9fadf06146cc"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Expand All @@ -26,12 +25,12 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StanSample = "c1514b29-d3a0-5178-b312-660c88baa699"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TransformVariables = "84d833dd-6860-57f9-a1a7-6da5db126cff"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
ApproxBayes = "0.3"
DiffEqBase = "6.36"
DiffResults = "0.0.4, 1.0"
Distances = "0.8, 0.9, 0.10"
Expand All @@ -46,14 +45,13 @@ Missings = "0.4, 1.0"
ModelingToolkit = "5.6"
Optim = "0.19, 0.20, 0.21, 0.22, 1.0"
PDMats = "0.9, 0.10, 0.11"
ParameterizedFunctions = "5"
Parameters = "0.12"
RecursiveArrayTools = "1,2"
Reexport = "0.2, 1.0"
Requires = "0.5, 1.0"
StructArrays = "0.4, 0.5"
TransformVariables = "0.3, 0.4"
Turing = "0.12, 0.13, 0.14, 0.15"
Turing = "0.12, 0.13, 0.14, 0.15, 0.16"
julia = "1.3"

[extras]
Expand Down
16 changes: 5 additions & 11 deletions src/DiffEqBayes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,23 @@ using DocStringExtensions
using DiffEqBase, Distributions, Turing, MacroTools
using RecursiveArrayTools, ModelingToolkit
using Parameters, Distributions, Optim, Requires
using Distances, ApproxBayes, DocStringExtensions, Random
using Distances, DocStringExtensions, Random, StanSample

STANDARD_PROB_GENERATOR(prob,p) = remake(prob;u0=eltype(p).(prob.u0),p=p)
STANDARD_PROB_GENERATOR(prob::EnsembleProblem,p) = EnsembleProblem(remake(prob.prob;u0=eltype(p).(prob.prob.u0),p=p))

include("turing_inference.jl")
include("abc_inference.jl")
# include("abc_inference.jl")
include("stan_string.jl")
include("stan_inference.jl")

function __init__()
@require CmdStan="593b3428-ca2f-500c-ae53-031589ec8ddd" begin
using .CmdStan
include("stan_inference.jl")
include("stan_string.jl")
export stan_inference, stan_string
end

@require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
using .DynamicHMC, TransformVariables, LogDensityProblems
include("dynamichmc_inference.jl")
export dynamichmc_inference
end
end

export turing_inference, abc_inference

export turing_inference, stan_inference ,abc_inference
end # module
2 changes: 1 addition & 1 deletion src/dynamichmc_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,5 @@ function dynamichmc_inference(problem::DiffEqBase.DEProblem, algorithm, t, data,
= TransformedLogDensity(trans, P)
∇ℓ = LogDensityProblems.ADgradient(AD_gradient_kind, ℓ)
results = mcmc_with_warmup(rng, ∇ℓ, num_samples; mcmc_kwargs...)
merge((posterior = transform.(Ref(trans), results.chain), ), results)
merge((posterior = TransformVariables.transform.(Ref(trans), results.chain), ), results)
end
78 changes: 40 additions & 38 deletions src/stan_inference.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
struct StanModel{M,R,C,N}
struct StanResult{M,R,C}
model::M
return_code::R
chains::C
cnames::N
end

function Base.show(io::IO, mime::MIME"text/plain", res::StanResult)
show(io, mime, res.chains)
end

struct StanODEData
end

function generate_priors(n,priors)
priors_string = ""
if priors==nothing
if priors===nothing
for i in 1:n
priors_string = string(priors_string,"theta[$i] ~ normal(0, 1)", " ; ")
priors_string = string(priors_string,"theta_$i ~ normal(0, 1)", " ; ")
end
else
for i in 1:n
priors_string = string(priors_string,"theta[$i] ~",stan_string(priors[i]),";")
priors_string = string(priors_string,"theta_$i ~ ",stan_string(priors[i]),";")
end
end
priors_string
Expand All @@ -34,13 +37,13 @@ function generate_theta(n,priors)
lower_bound = string("lower=",minimum(priors[i]))
end
if lower_bound != "" && upper_bound != ""
theta = string(theta,"real","<$lower_bound",",","$upper_bound>"," theta$i",";")
theta = string(theta,"real","<$lower_bound",",","$upper_bound>"," theta_$i",";")
elseif lower_bound != ""
theta = string(theta,"real","<$lower_bound",">"," theta$i",";")
theta = string(theta,"real","<$lower_bound",">"," theta_$i",";")
elseif upper_bound != ""
theta = string(theta,"real","<","$upper_bound>"," theta$i",";")
theta = string(theta,"real","<","$upper_bound>"," theta_$i",";")
else
theta = string(theta,"real"," theta$i",";")
theta = string(theta,"real"," theta_$i",";")
end
end
return theta
Expand All @@ -50,9 +53,10 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
stanmodel = nothing;alg=:rk45,
num_samples=1000, num_warmup=1000, reltol=1e-3,
abstol=1e-6, maxiter=Int(1e5),likelihood=Normal,
vars=(StanODEData(),InverseGamma(3,3)),nchains=1,
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing, printsummary = true)

vars=(StanODEData(),InverseGamma(3,3)),nchains=[1],
sample_u0 = false, save_idxs = nothing, diffeq_string = nothing,
printsummary = true, output_format = :mcmcchains)

save_idxs !== nothing && length(save_idxs) == 1 ? save_idxs = save_idxs[1] : save_idxs = save_idxs
length_of_y = length(prob.u0)
save_idxs = something(save_idxs, 1:length_of_y)
Expand All @@ -63,24 +67,26 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
else
length_of_parameter = length(prob.p) + sample_u0 * length(save_idxs)
end

if stanmodel === nothing
if alg ==:adams
algorithm = "integrate_ode_adams"
algorithm = "ode_adams_tol"
elseif alg ==:rk45
algorithm = "integrate_ode_rk45"
algorithm = "ode_rk45_tol"
elseif alg == :bdf
algorithm = "integrate_ode_bdf"
algorithm = "ode_bdf_tol"
else
error("The choices for alg are :adams, :rk45, or :bdf")
end
hyper_params = ""
tuple_hyper_params = ""
setup_params = ""
thetas = ""
theta_names = ""
theta_string = generate_theta(length_of_parameter,priors)
for i in 1:length_of_parameter
thetas = string(thetas,"theta[$i] = theta$i",";")
thetas = string(thetas,"real theta_$i",";")
theta_names = string(theta_names,"theta_$i",",")
end
for i in 1:length_of_params
if isa(vars[i],StanODEData)
Expand All @@ -97,18 +103,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
stan_likelihood = stan_string(likelihood)
if sample_u0
nu = length(save_idxs)
dv_names_ind = findfirst("$nu", theta_names)[1]
if nu < length(prob.u0)
u0 = "{"
u0 = ""
for u_ in prob.u0[nu+1:length(prob.u0)]
u0 = u0*string(u_)
end
u0 = u0*"}"
integral_string = "u_hat = $algorithm(sho, append_array(theta[1:$nu],$u0), t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
else
integral_string = "u_hat = $algorithm(sho, theta[1:$nu], t0, ts, theta[$(nu+1):$length_of_parameter], x_r, x_i, $reltol, $abstol, $maxiter);"
integral_string = "u_hat = $algorithm(sho, [$(theta_names[1:dv_names_ind]),$u0]', t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names[dv_names_ind+2:end],',')));"
else
integral_string = "u_hat = $algorithm(sho, [$(theta_names[1:dv_names_ind])]', t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names[dv_names_ind+2:end],',')));"
end
else
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, theta, x_r, x_i, $reltol, $abstol, $maxiter);"
integral_string = "u_hat = $algorithm(sho, u0, t0, ts, $reltol, $abstol, $maxiter, $(rstrip(theta_names,',')));"
end
binsearch_string = """
int bin_search(real x, int min_val, int max_val){
Expand All @@ -120,8 +126,8 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
out = mid_pt;
range = 0;
} else {
range = (range + 1) / 2;
mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
range = (range + 1) / 2;
mid_pt = x > mid_pt ? mid_pt + range: mid_pt - range;
}
}
return out;
Expand All @@ -141,26 +147,18 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
$diffeq_string
}
data {
real u0[$length_of_y];
vector[$length_of_y] u0;
int<lower=1> T;
real internal_var___u[T,$(length(save_idxs))];
real t0;
real ts[T];
}
transformed data {
real x_r[0];
int x_i[0];
}
parameters {
$setup_params
$theta_string
}
transformed parameters{
real theta[$length_of_parameter];
$thetas
}
model{
real u_hat[T,$length_of_y];
vector[$length_of_y] u_hat[T];
$hyper_params
$priors_string
$integral_string
Expand All @@ -169,9 +167,13 @@ function stan_inference(prob::DiffEqBase.DEProblem,t,data,priors = nothing,
}
}
"
stanmodel = CmdStan.Stanmodel(num_samples=num_samples, num_warmup=num_warmup, name="parameter_estimation_model", model=parameter_estimation_model, nchains=nchains, printsummary = printsummary)
stanmodel = StanSample.SampleModel("parameter_estimation_model", parameter_estimation_model, nchains; printsummary = printsummary, method = StanSample.Sample(;num_samples = num_samples, num_warmup = num_warmup))
end
parameter_estimation_data = Dict("u0"=>prob.u0, "T" => length(t), "internal_var___u" => view(data, :, 1:length(t))', "t0" => prob.tspan[1], "ts" => t)
return_code, chains, cnames = CmdStan.stan(stanmodel, [parameter_estimation_data])
return StanModel(stanmodel, return_code, chains, cnames)
rc = stan_sample(stanmodel; data = parameter_estimation_data)
if success(rc)
return StanResult(stanmodel, rc, read_samples(stanmodel; output_format=output_format))
else
rc.err
end
end
4 changes: 1 addition & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ const GROUP = get(ENV, "GROUP", "All")
if GROUP == "All" || GROUP == "Core"
@time @safetestset "DynamicHMC" begin include("dynamicHMC.jl") end
@time @safetestset "Turing" begin include("turing.jl") end
@time @safetestset "ABC" begin include("abc.jl") end
# @time @safetestset "ABC" begin include("abc.jl") end
end

if GROUP == "Stan" || GROUP == "All"
using Pkg
Pkg.add("CmdStan")
@time @safetestset "Stan_String" begin include("stan_string.jl") end
@time @safetestset "Stan" begin include("stan.jl") end
end
32 changes: 13 additions & 19 deletions test/stan.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using CmdStan, DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
using DiffEqBayes, OrdinaryDiffEq, ParameterizedFunctions,
RecursiveArrayTools, Distributions, Test

println("One parameter case")
Expand All @@ -19,24 +19,21 @@ priors = [truncated(Normal(1.5,0.1),1.0,1.8)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1.5 atol=3e-1

# Test norecompile
bayesian_result2 = stan_inference(prob1,t,data,priors,bayesian_result.model;
num_samples=300,num_warmup=500,likelihood=Normal)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
@test mean(get(bayesian_result2.chains,:theta_1)[1]) 1.5 atol=3e-1

priors = [truncated(Normal(1.,0.01),0.5,2.0),truncated(Normal(1.,0.01),0.5,2.0),truncated(Normal(1.5,0.01),1.0,2.0)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal,sample_u0=true)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] 1. atol=3e-1
@test sdf[sdf.parameters .== :theta2, :mean][1] 1. atol=3e-1
@test sdf[sdf.parameters .== :theta3, :mean][1] 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1. atol=3e-1
@test mean(get(bayesian_result.chains,:theta_2)[1]) 1. atol=3e-1
@test mean(get(bayesian_result.chains,:theta_3)[1]) 1.5 atol=3e-1

sol = solve(prob1,Tsit5(),save_idxs=[1])
randomized = VectorOfArray([(sol(t[i]) + .01 * randn(1)) for i in 1:length(t)])
Expand All @@ -45,17 +42,15 @@ priors = [truncated(Normal(1.5,0.1),0.5,2)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal,save_idxs=[1])

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1.5 atol=3e-1


priors = [truncated(Normal(1.,0.01),0.5,2),truncated(Normal(1.5,0.01),0.5,2)]
bayesian_result = stan_inference(prob1,t,data,priors;num_samples=300,
num_warmup=500,likelihood=Normal,save_idxs=[1],sample_u0=true)

sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] 1. atol=3e-1
@test sdf[sdf.parameters .== :theta2, :mean][1] 1.5 atol=3e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1. atol=3e-1
@test mean(get(bayesian_result.chains,:theta_2)[1]) 1.5 atol=3e-1

println("Four parameter case")
f1 = @ode_def begin
Expand All @@ -74,8 +69,7 @@ priors = [truncated(Normal(1.5,0.01),0.5,2),truncated(Normal(1.0,0.01),0.5,1.5),
truncated(Normal(3.0,0.01),0.5,4),truncated(Normal(1.0,0.01),0.5,2)]

bayesian_result = stan_inference(prob1,t,data,priors;num_samples=100,num_warmup=500,vars =(DiffEqBayes.StanODEData(),InverseGamma(4,1)))
sdf = CmdStan.read_summary(bayesian_result.model)
@test sdf[sdf.parameters .== :theta1, :mean][1] 1.5 atol=1e-1
@test sdf[sdf.parameters .== :theta2, :mean][1] 1.0 atol=1e-1
@test sdf[sdf.parameters .== :theta3, :mean][1] 3.0 atol=1e-1
@test sdf[sdf.parameters .== :theta4, :mean][1] 1.0 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_1)[1]) 1.5 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_2)[1]) 1.0 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_3)[1]) 3.0 atol=1e-1
@test mean(get(bayesian_result.chains,:theta_4)[1]) 1.0 atol=1e-1
Loading

0 comments on commit ea7fa7c

Please sign in to comment.