Skip to content

Commit

Permalink
Fix save_end overriding behavior
Browse files Browse the repository at this point in the history
Fixes #1842
  • Loading branch information
ChrisRackauckas committed Dec 28, 2023
1 parent f1b8d90 commit 40dfbd1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
5 changes: 5 additions & 0 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ function _savevalues!(integrator, force_save, reduce_size)::Tuple{Bool, Bool}
integrator.cache.current)
end
else # ==t, just save
if curt == integrator.sol.prob.tspan[2] && !integrator.opts.save_end
integrator.saveiter -= 1
continue
end
savedexactly = true
copyat_or_push!(integrator.sol.t, integrator.saveiter, integrator.t)
if integrator.opts.save_idxs === nothing
Expand Down Expand Up @@ -145,6 +149,7 @@ postamble!(integrator::ODEIntegrator) = _postamble!(integrator)
function _postamble!(integrator)
DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator)
solution_endpoint_match_cur_integrator!(integrator)
save
resize!(integrator.sol.t, integrator.saveiter)
resize!(integrator.sol.u, integrator.saveiter)
if !(integrator.sol isa DAESolution)
Expand Down
13 changes: 10 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,16 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,
sizehint!(ts, 50)
sizehint!(ks, 50)
elseif !isempty(saveat_internal)
sizehint!(timeseries, length(saveat_internal) + 1)
sizehint!(ts, length(saveat_internal) + 1)
sizehint!(ks, length(saveat_internal) + 1)
savelength = length(saveat_internal) + 1
if save_start == false
savelength -= 1
end
if save_end == false && prob.tspan[2] in saveat_internal.valtree
savelength -= 1
end
sizehint!(timeseries, savelength)
sizehint!(ts, savelength)
sizehint!(ks, savelength)
else
sizehint!(timeseries, 2)
sizehint!(ts, 2)
Expand Down
25 changes: 25 additions & 0 deletions test/interface/ode_saveat_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,28 @@ prob = ODEProblem(SIR!, [0.99, 0.01, 0.0], (t_obs[1], t_obs[end]), [0.20, 0.15])
sol = solve(prob, DP5(), reltol = 1e-6, abstol = 1e-6, saveat = t_obs)
@test maximum(sol) <= 1
@test minimum(sol) >= 0

@testset "Proper save_start and save_end behavior" begin
function f2(du, u, p, t)
du[1] = -cos(u[1]) * u[1]
end
prob = ODEProblem(f2, [10], (0.0, 0.4))

@test solve(prob, Tsit5(); saveat = 0:.1:.4).t == [0.0; 0.1; 0.2; 0.3; 0.4]
@test solve(prob, Tsit5(); saveat = 0:.1:.4, save_start = true, save_end = true).t == [0.0; 0.1; 0.2; 0.3; 0.4]
@test solve(prob, Tsit5(); saveat = 0:.1:.4, save_start = false, save_end = false).t == [0.1; 0.2; 0.3]

ts = solve(prob, Tsit5()).t
@test 0.0 in ts
@test 0.4 in ts
ts = solve(prob, Tsit5(); save_start = true, save_end = true).t
@test 0.0 in ts
@test 0.4 in ts
ts = solve(prob, Tsit5(); save_start = false, save_end = false).t
@test 0.0 ts
@test 0.4 ts

@test solve(prob, Tsit5(); saveat = [.2]).t == [0.2]
@test solve(prob, Tsit5(); saveat = [.2], save_start = true, save_end = true).t == [0.0; 0.2; 0.4]
@test solve(prob, Tsit5(); saveat = [.2], save_start = false, save_end = false).t == [0.2]
end

0 comments on commit 40dfbd1

Please sign in to comment.