Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "2.13.2"
version = "2.13.3"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand Down
17 changes: 13 additions & 4 deletions src/saving.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,13 @@ mutable struct SavingAffect{SaveFunc, tType, savevalType, saveatType, saveatCach
saveat_cache::saveatCacheType
save_everystep::Bool
save_start::Bool
save_end::Bool
saveiter::Int
end

function (affect!::SavingAffect)(integrator,force_save = false)

just_saved = false
# see OrdinaryDiffEq.jl -> integrator_utils.jl, function savevalues!
while !isempty(affect!.saveat) && integrator.tdir*top(affect!.saveat) <= integrator.tdir*integrator.t # Perform saveat
affect!.saveiter += 1
Expand All @@ -60,11 +63,15 @@ function (affect!::SavingAffect)(integrator,force_save = false)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter,
affect!.save_func(curu, curt, integrator),Val{false})
else # ==t, just save
just_saved = true
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
end
end
if affect!.save_everystep || force_save
if !just_saved &&
affect!.save_everystep || force_save ||
(affect!.save_end && integrator.t == integrator.sol.prob.tspan[end])

affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(integrator.u, integrator.t, integrator),Val{false})
Expand All @@ -89,7 +96,8 @@ end
SavingCallback(save_func, saved_values::SavedValues;
saveat=Vector{eltype(saved_values.t)}(),
save_everystep=isempty(saveat),
save_start = true,
save_start = save_everystep || isempty(saveat) || saveat isa Number,
save_end = save_everystep || isempty(saveat) || saveat isa Number,
tdir=1)

A `DiscreteCallback` applied after every step, saving the time `t` and the value
Expand All @@ -104,7 +112,8 @@ If the time `tdir` direction is not positive, i.e. `tspan[1] > tspan[2]`,
function SavingCallback(save_func, saved_values::SavedValues;
saveat=Vector{eltype(saved_values.t)}(),
save_everystep=isempty(saveat),
save_start = true,
save_start = save_everystep || isempty(saveat) || saveat isa Number,
save_end = save_everystep || isempty(saveat) || saveat isa Number,
tdir=1)
# saveat conversions, see OrdinaryDiffEq.jl -> integrators/type.jl
saveat_vec = collect(saveat)
Expand All @@ -113,7 +122,7 @@ function SavingCallback(save_func, saved_values::SavedValues;
else
saveat_internal = BinaryMaxHeap(saveat_vec)
end
affect! = SavingAffect(save_func, saved_values, saveat_internal, saveat_vec, save_everystep, save_start, 0)
affect! = SavingAffect(save_func, saved_values, saveat_internal, saveat_vec, save_everystep, save_start, save_end, 0)
condtion = (u, t, integrator) -> true
DiscreteCallback(condtion, affect!;
initialize = saving_initialize,
Expand Down
10 changes: 10 additions & 0 deletions test/saving_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,13 @@ saved_values = SavedValues(Float64, Tuple{Float64,Float64})
cb = SavingCallback((u,t,integrator)->(tr(u),norm(u)), saved_values, saveat=0.0:0.1:1.0)
sol = solve(prob, Tsit5(), callback=cb)
println(saved_values.saveval)

# Save only end
prob = ODEProblem((du,u,p,t) -> du .= u, rand(4,4), (0.0,1.0))
saved_values = SavedValues(Float64, Tuple{Float64,Float64})
cb = SavingCallback((u,t,integrator)->(tr(u),norm(u)), saved_values,
save_everystep = false, save_start = false)
sol = solve(prob, Tsit5(), callback=cb)
print(saved_values.saveval)
@test length(saved_values.t) == 1
@test saved_values.t[1] == 1.0