Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 8 additions & 31 deletions src/common_interface/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator,

tmp = integrator(curt)
save_value!(integrator.sol.u, tmp, uType,
integrator.opts.save_idxs, Val{false})
integrator.opts.save_idxs, false)
push!(integrator.sol.t, curt)
if integrator.opts.dense
tmp = integrator(curt, Val{1})
save_value!(integrator.sol.interp.du, tmp, uType,
integrator.opts.save_idxs, Val{false})
integrator.opts.save_idxs, false)
end
end

Expand All @@ -69,7 +69,7 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator,
integrator.opts.save_idxs)
push!(integrator.sol.t, integrator.t)
if integrator.opts.dense
tmp = integrator(integrator.t, Val{1})
tmp = DiffEqBase.get_du(integrator)
save_value!(integrator.sol.interp.du, tmp, uType,
integrator.opts.save_idxs)
end
Expand All @@ -78,37 +78,14 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator,
return saved, savedexactly
end

function save_value!(save_array,
val,
::Type{T},
save_idxs,
make_copy::Type{Val{bool}} = Val{true}) where {T <: Number, bool}
function save_value!(save_array, val, ::Type{<:Number}, save_idxs, make_copy::Bool = true)
push!(save_array, first(val))
end
function save_value!(save_array,
val,
::Type{T},
save_idxs,
make_copy::Type{Val{bool}} = Val{true}) where {T <: Vector, bool}
@assert val isa Array
function save_value!(save_array, val, ::Type{<:AbstractArray}, save_idxs, make_copy::Bool = true)
save = if save_idxs !== nothing
val[save_idxs]
else
bool ? copy(val) : val
end
push!(save_array, save)
end
function save_value!(save_array,
val,
::Type{T},
save_idxs,
make_copy::Type{Val{bool}} = Val{true}) where {T <: AbstractArray, bool
}
@assert val isa Array
save = if save_idxs !== nothing
val[save_idxs]
else
x = bool ? copy(val) : val
make_copy ? copy(val) : val
end
push!(save_array, save)
end
Expand Down Expand Up @@ -151,11 +128,11 @@ function DiffEqBase.terminate!(integrator::AbstractSundialsIntegrator,
integrator.opts.tstops.valtree = typeof(integrator.opts.tstops.valtree)()
end

@inline function DiffEqBase.get_du(integrator::CVODEIntegrator)
@inline function DiffEqBase.get_du(integrator::AbstractSundialsIntegrator)
integrator(integrator.t, Val{1})
end

@inline function DiffEqBase.get_du!(out, integrator::CVODEIntegrator)
@inline function DiffEqBase.get_du!(out, integrator::AbstractSundialsIntegrator)
integrator(out, integrator.t, Val{1})
end

Expand Down
24 changes: 12 additions & 12 deletions src/common_interface/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
callback_cache = nothing
end

tspan = prob.tspan
tspan = Float64.(prob.tspan)
t0 = tspan[1]

tdir = sign(tspan[2] - tspan[1])
Expand Down Expand Up @@ -213,9 +213,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i

flag = CVodeInit(mem, getcfunf(userfun), t0, utmp)

dt !== nothing && (flag = CVodeSetInitStep(mem, dt))
flag = CVodeSetMinStep(mem, dtmin)
flag = CVodeSetMaxStep(mem, dtmax)
dt !== nothing && (flag = CVodeSetInitStep(mem, Float64(dt)))
flag = CVodeSetMinStep(mem, Float64(dtmin))
flag = CVodeSetMaxStep(mem, Float64(dtmax))
flag = CVodeSetUserData(mem, userfun)
if abstol isa Array
flag = CVodeSVtolerances(mem, reltol, abstol)
Expand Down Expand Up @@ -611,9 +611,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i
end
end

dt !== nothing && (flag = ARKStepSetInitStep(mem, dt))
flag = ARKStepSetMinStep(mem, dtmin)
flag = ARKStepSetMaxStep(mem, dtmax)
dt !== nothing && (flag = ARKStepSetInitStep(mem, Float64(dt)))
flag = ARKStepSetMinStep(mem, Float64(dtmin))
flag = ARKStepSetMaxStep(mem, Float64(dtmax))
flag = ARKStepSetUserData(mem, userfun)
if abstol isa Array
flag = ARKStepSVtolerances(mem, reltol, abstol)
Expand Down Expand Up @@ -1411,7 +1411,7 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free =
solver_step(integrator, tstop)
integrator.t = first(integrator.tout)
integrator.flag < 0 && break
handle_callbacks!(integrator)
handle_callbacks!(integrator) # this also updates the interpolation
integrator.flag < 0 && break
if isempty(integrator.opts.tstops)
break
Expand All @@ -1426,14 +1426,14 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free =
handle_tstop!(integrator)
end

tend = integrator.t
if integrator.opts.save_end &&
(isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t)
(isempty(integrator.sol.t) || integrator.sol.t[end] != tend)
save_value!(integrator.sol.u, integrator.u, uType,
integrator.opts.save_idxs)
push!(integrator.sol.t, integrator.t)
push!(integrator.sol.t, tend)
if integrator.opts.dense
integrator(integrator.u, integrator.t, Val{1})
save_value!(integrator.sol.interp.du, integrator.u, uType,
save_value!(integrator.sol.interp.du, get_du(integrator), uType,
integrator.opts.save_idxs)
end
end
Expand Down
25 changes: 25 additions & 0 deletions test/interpolation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using Sundials, Test, DiffEqBase
using ForwardDiff
import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear

function regression_test(alg, tol_ode_linear, tol_ode_2Dlinear)
sol = solve(prob_ode_linear, alg, dense = true, abstol=1e-8, reltol=1e-8)
@inferred sol(.5)
u0 = sol[1]
p = sol.prob.p
for t in 0.0:1/16:1.0
@test isapprox(u0 * exp(p*t), sol(t), rtol=tol_ode_linear)
end

sol = solve(prob_ode_2Dlinear, alg, dt = 1 / 2^(2), dense = true)
sol2 = solve(prob_ode_2Dlinear, alg, dense = true, abstol=1e-8, reltol=1e-8)
u0 = sol[1]
p = sol.prob.p
for t in 0.0:1/16:1.0
@test isapprox(u0 .* exp(p*t), sol(t), rtol=tol_ode_2Dlinear)
end
end

regression_test(ARKODE(), 1e-5, 1e-4)
regression_test(CVODE_BDF(), 1e-6, 1e-2)
regression_test(CVODE_Adams(), 1e-6, 1e-3)
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,5 @@ end
@testset "Mass Matrix" begin include("common_interface/mass_matrix.jl") end
@testset "Preconditioners" begin include("common_interface/precs.jl") end
end

@testset "Interpolation" begin include("interpolation.jl") end