From 2603b145780a24f796085f969f40999b80b16007 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Wed, 21 Jun 2023 17:28:33 -0400 Subject: [PATCH 1/3] This fixes https://github.com/SciML/Sundials.jl/issues/292. It turns out we were asking for interpolated derivatives for essentially no reason (and doing so caused problems whenever you hit callbacks). As a nice side-benefit, this removes the 200 warnings --- src/common_interface/integrator_utils.jl | 6 +++--- src/common_interface/solve.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index 365f02e6..c4c826fa 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -69,7 +69,7 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, integrator.opts.save_idxs) push!(integrator.sol.t, integrator.t) if integrator.opts.dense - tmp = integrator(integrator.t, Val{1}) + tmp = DiffEqBase.get_du(integrator) save_value!(integrator.sol.interp.du, tmp, uType, integrator.opts.save_idxs) end @@ -151,11 +151,11 @@ function DiffEqBase.terminate!(integrator::AbstractSundialsIntegrator, integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)() end -@inline function DiffEqBase.get_du(integrator::CVODEIntegrator) +@inline function DiffEqBase.get_du(integrator::AbstractSundialsIntegrator) integrator(integrator.t, Val{1}) end -@inline function DiffEqBase.get_du!(out, integrator::CVODEIntegrator) +@inline function DiffEqBase.get_du!(out, integrator::AbstractSundialsIntegrator) integrator(out, integrator.t, Val{1}) end diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 1d43dd39..02e45145 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -1432,7 +1432,6 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = integrator.opts.save_idxs) push!(integrator.sol.t, integrator.t) if integrator.opts.dense - integrator(integrator.u, integrator.t, Val{1}) save_value!(integrator.sol.interp.du, integrator.u, uType, integrator.opts.save_idxs) end From 7c3571798bf4ca1110e3caa3144af88354ef4ed1 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 27 Jun 2023 11:48:40 -0400 Subject: [PATCH 2/3] cleanup --- src/common_interface/integrator_utils.jl | 29 +++--------------------- src/common_interface/solve.jl | 23 ++++++++++--------- 2 files changed, 15 insertions(+), 37 deletions(-) diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index c4c826fa..c54674f5 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -78,37 +78,14 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, return saved, savedexactly end -function save_value!(save_array, - val, - ::Type{T}, - save_idxs, - make_copy::Type{Val{bool}} = Val{true}) where {T <: Number, bool} +function save_value!(save_array, val, ::Type{<:Number}, save_idxs, make_copy::Bool = true) push!(save_array, first(val)) end -function save_value!(save_array, - val, - ::Type{T}, - save_idxs, - make_copy::Type{Val{bool}} = Val{true}) where {T <: Vector, bool} - @assert val isa Array +function save_value!(save_array, val, ::Type{<:AbstractArray}, save_idxs, make_copy::Bool = true) save = if save_idxs !== nothing val[save_idxs] else - bool ? copy(val) : val - end - push!(save_array, save) -end -function save_value!(save_array, - val, - ::Type{T}, - save_idxs, - make_copy::Type{Val{bool}} = Val{true}) where {T <: AbstractArray, bool - } - @assert val isa Array - save = if save_idxs !== nothing - val[save_idxs] - else - x = bool ? copy(val) : val + make_copy ? copy(val) : val end push!(save_array, save) end diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index 02e45145..86bd6fed 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -140,7 +140,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i callback_cache = nothing end - tspan = prob.tspan + tspan = Float64.(prob.tspan) t0 = tspan[1] tdir = sign(tspan[2] - tspan[1]) @@ -213,9 +213,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i flag = CVodeInit(mem, getcfunf(userfun), t0, utmp) - dt !== nothing && (flag = CVodeSetInitStep(mem, dt)) - flag = CVodeSetMinStep(mem, dtmin) - flag = CVodeSetMaxStep(mem, dtmax) + dt !== nothing && (flag = CVodeSetInitStep(mem, Float64(dt))) + flag = CVodeSetMinStep(mem, Float64(dtmin)) + flag = CVodeSetMaxStep(mem, Float64(dtmax)) flag = CVodeSetUserData(mem, userfun) if abstol isa Array flag = CVodeSVtolerances(mem, reltol, abstol) @@ -611,9 +611,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i end end - dt !== nothing && (flag = ARKStepSetInitStep(mem, dt)) - flag = ARKStepSetMinStep(mem, dtmin) - flag = ARKStepSetMaxStep(mem, dtmax) + dt !== nothing && (flag = ARKStepSetInitStep(mem, Float64(dt))) + flag = ARKStepSetMinStep(mem, Float64(dtmin)) + flag = ARKStepSetMaxStep(mem, Float64(dtmax)) flag = ARKStepSetUserData(mem, userfun) if abstol isa Array flag = ARKStepSVtolerances(mem, reltol, abstol) @@ -1411,7 +1411,7 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = solver_step(integrator, tstop) integrator.t = first(integrator.tout) integrator.flag < 0 && break - handle_callbacks!(integrator) + handle_callbacks!(integrator) # this also updates the interpolation integrator.flag < 0 && break if isempty(integrator.opts.tstops) break @@ -1426,13 +1426,14 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = handle_tstop!(integrator) end + tend = integrator.t if integrator.opts.save_end && - (isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t) + (isempty(integrator.sol.t) || integrator.sol.t[end] != tend) save_value!(integrator.sol.u, integrator.u, uType, integrator.opts.save_idxs) - push!(integrator.sol.t, integrator.t) + push!(integrator.sol.t, tend) if integrator.opts.dense - save_value!(integrator.sol.interp.du, integrator.u, uType, + save_value!(integrator.sol.interp.du, get_du(integrator), uType, integrator.opts.save_idxs) end end From cae63242c3d23e9d130afec5f05390ccb4ed8c45 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Tue, 27 Jun 2023 14:16:09 -0400 Subject: [PATCH 3/3] add tests and fixes --- src/common_interface/integrator_utils.jl | 4 ++-- test/interpolation.jl | 25 ++++++++++++++++++++++++ test/runtests.jl | 2 ++ 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 test/interpolation.jl diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index c54674f5..6c6b0e4a 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -52,12 +52,12 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, tmp = integrator(curt) save_value!(integrator.sol.u, tmp, uType, - integrator.opts.save_idxs, Val{false}) + integrator.opts.save_idxs, false) push!(integrator.sol.t, curt) if integrator.opts.dense tmp = integrator(curt, Val{1}) save_value!(integrator.sol.interp.du, tmp, uType, - integrator.opts.save_idxs, Val{false}) + integrator.opts.save_idxs, false) end end diff --git a/test/interpolation.jl b/test/interpolation.jl new file mode 100644 index 00000000..7d242f60 --- /dev/null +++ b/test/interpolation.jl @@ -0,0 +1,25 @@ +using Sundials, Test, DiffEqBase +using ForwardDiff +import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear + +function regression_test(alg, tol_ode_linear, tol_ode_2Dlinear) + sol = solve(prob_ode_linear, alg, dense = true, abstol=1e-8, reltol=1e-8) + @inferred sol(.5) + u0 = sol[1] + p = sol.prob.p + for t in 0.0:1/16:1.0 + @test isapprox(u0 * exp(p*t), sol(t), rtol=tol_ode_linear) + end + + sol = solve(prob_ode_2Dlinear, alg, dt = 1 / 2^(2), dense = true) + sol2 = solve(prob_ode_2Dlinear, alg, dense = true, abstol=1e-8, reltol=1e-8) + u0 = sol[1] + p = sol.prob.p + for t in 0.0:1/16:1.0 + @test isapprox(u0 .* exp(p*t), sol(t), rtol=tol_ode_2Dlinear) + end +end + +regression_test(ARKODE(), 1e-5, 1e-4) +regression_test(CVODE_BDF(), 1e-6, 1e-2) +regression_test(CVODE_Adams(), 1e-6, 1e-3) diff --git a/test/runtests.jl b/test/runtests.jl index 3a32188b..b876dd4f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,3 +41,5 @@ end @testset "Mass Matrix" begin include("common_interface/mass_matrix.jl") end @testset "Preconditioners" begin include("common_interface/precs.jl") end end + +@testset "Interpolation" begin include("interpolation.jl") end