diff --git a/README.md b/README.md index 311e4b54..25aecb92 100644 --- a/README.md +++ b/README.md @@ -94,3 +94,23 @@ StepsizeLimiter(dtFE;safety_factor=9//10,max_step=false,cached_dtcache=0.0) which defaults to `9//10`. `max_step=true` makes every step equal to `safety_factor*dtFE(t,u)` when the solver is set to `adaptive=false`. `cached_dtcache` should be set to match the type for time when not using Float64 values. + +## SavingCallback + +The aving callback lets you define a function `save_func(t, u, integrator)` which +returns quantities of interest that shall be saved. The constructor is: + +```julia +SavingCallback(save_func, saved_values::SavedValues; + saveat=Vector{eltype(saved_values.t)}(), + save_everystep=isempty(saveat), + tdir=1) +``` +- `save_func(t, u, integrator)` returns the quantities which shall be saved. +- `saved_values::SavedValues` contains vectors `t::Vector{tType}`, + `saveval::Vector{savevalType}` of the saved quantities. Here, + `save_func(t, u, integrator)::savevalType`. +- `saveat` Mimicks `saveat` in `solve` for ODEs. +- `save_everystep` Mimicks `save_everystep` in `solve` for ODEs. +- `tdir` should be `sign(tspan[end]-tspan[1])`. It defaults to `1` and should + be adapted if `tspan[1] > tspan[end]`. diff --git a/REQUIRE b/REQUIRE index c6cfaf8d..eebd714a 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,6 +1,8 @@ julia 0.6.0 DiffEqBase 0.6.0 NLsolve -ForwardDiff 0.5.0 +ForwardDiff 0.5 0.6- DiffBase OrdinaryDiffEq +RecursiveArrayTools +DataStructures diff --git a/src/DiffEqCallbacks.jl b/src/DiffEqCallbacks.jl index e51bb96e..672f9ab8 100644 --- a/src/DiffEqCallbacks.jl +++ b/src/DiffEqCallbacks.jl @@ -2,14 +2,14 @@ __precompile__() module DiffEqCallbacks - using DiffEqBase, NLsolve, ForwardDiff - import DiffBase + using DiffEqBase, NLsolve, ForwardDiff, RecursiveArrayTools, DataStructures - import OrdinaryDiffEq: fix_dt_at_bounds!, modify_dt_for_tstops! + import OrdinaryDiffEq: fix_dt_at_bounds!, modify_dt_for_tstops!, ode_addsteps!, ode_interpolant include("autoabstol.jl") include("manifold.jl") include("domain.jl") include("stepsizelimiters.jl") + include("saving.jl") end # module diff --git a/src/manifold.jl b/src/manifold.jl index 9b6cdd50..7d83c486 100644 --- a/src/manifold.jl +++ b/src/manifold.jl @@ -11,11 +11,10 @@ Base.@pure function determine_chunksize(u,CS) end function autodiff_setup(f!, initial_x, chunk_size::Type{Val{CS}}) where CS - permf! = (fx, x) -> f!(reshape(x,size(initial_x)...), fx) fx2 = copy(initial_x) - jac_cfg = ForwardDiff.JacobianConfig(nothing, initial_x, initial_x, ForwardDiff.Chunk{CS}()) + jac_cfg = ForwardDiff.JacobianConfig(permf!, initial_x, initial_x, ForwardDiff.Chunk{CS}()) g! = (x, gx) -> ForwardDiff.jacobian!(gx, permf!, fx2, x, jac_cfg) fg! = (x, fx, gx) -> begin diff --git a/src/saving.jl b/src/saving.jl new file mode 100644 index 00000000..18bd5656 --- /dev/null +++ b/src/saving.jl @@ -0,0 +1,100 @@ +""" + SavedValues{tType<:Real, savevalType} + +A struct used to save values of the time in `t::Vector{tType}` and +additional values in `saveval::Vector{savevalType}`. +""" +struct SavedValues{tType<:Real, savevalType} + t::Vector{tType} + saveval::Vector{savevalType} +end + +""" + SavedValues(tType::DataType, savevalType::DataType) + +Return `SavedValues{tType, savevalType}` with empty storage vectors. +""" +function SavedValues(tType::DataType, savevalType::DataType) + SavedValues{tType, savevalType}(Vector{tType}(), Vector{savevalType}()) +end + +function Base.show(io::IO, saved_values::SavedValues) + tType = eltype(saved_values.t) + savevalType = eltype(saved_values.saveval) + print(io, "SavedValues{tType=", tType, ", savevalType=", savevalType, "}", + "\nt:\n", saved_values.t, "\nsaveval:\n", saved_values.saveval) +end + + +mutable struct SavingAffect{SaveFunc, tType, savevalType, saveatType} + save_func::SaveFunc + saved_values::SavedValues{tType, savevalType} + saveat::saveatType + save_everystep::Bool + saveiter::Int +end + +function (affect!::SavingAffect)(integrator) + # see OrdinaryDiffEq.jl -> integrator_utils.jl, function savevalues! + while !isempty(affect!.saveat) && integrator.tdir*top(affect!.saveat) <= integrator.tdir*integrator.t # Perform saveat + affect!.saveiter += 1 + curt = pop!(affect!.saveat) # current time + if curt != integrator.t # If tspan[2]`, +`tdir = -1` has to be specified. +""" +function SavingCallback(save_func, saved_values::SavedValues; + saveat=Vector{eltype(saved_values.t)}(), + save_everystep=isempty(saveat), + tdir=1) + # saveat conversions, see OrdinaryDiffEq.jl -> integrators/type.jl + saveat_vec = collect(saveat) + if tdir > 0 + saveat_internal = binary_minheap(saveat_vec) + else + saveat_internal = binary_maxheap(saveat_vec) + end + affect! = SavingAffect(save_func, saved_values, saveat_internal, save_everystep, 0) + condtion = (t, u, integrator) -> true + DiscreteCallback(condtion, affect!; + initialize = saving_initialize, + save_positions=(false,false)) +end + + +export SavingCallback, SavedValues diff --git a/test/runtests.jl b/test/runtests.jl index b81a0618..b7853dc5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,4 +7,5 @@ tic() @time @testset "Domain tests" begin include("domain_tests.jl") end @time @testset "Manifold tests" begin include("manifold_tests.jl") end @time @testset "StepsizeLimiter tests" begin include("stepsizelimiter_tests.jl") end +@time @testset "Saving tests" begin include("saving_tests.jl") end toc() diff --git a/test/saving_tests.jl b/test/saving_tests.jl new file mode 100644 index 00000000..519f35a7 --- /dev/null +++ b/test/saving_tests.jl @@ -0,0 +1,72 @@ +using Base.Test, OrdinaryDiffEq, DiffEqProblemLibrary, DiffEqCallbacks + +# save_everystep, scalar problem +prob = prob_ode_linear +saved_values = SavedValues(Float64, Float64) +cb = SavingCallback((t,u,integrator)->u, saved_values) +sol = solve(prob, Tsit5(), callback=cb) +print("\n", saved_values, "\n") +@test all(idx -> sol.t[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> sol.u[idx] == saved_values.saveval[idx], eachindex(saved_values.t)) + +# save_everystep, inplace problem +prob2D = prob_ode_2Dlinear +saved_values = SavedValues(eltype(prob2D.tspan), typeof(prob2D.u0)) +cb = SavingCallback((t,u,integrator)->copy(u), saved_values) +sol = solve(prob2D, Tsit5(), callback=cb) +@test all(idx -> sol.t[idx] .== saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> all(sol.u[idx] .== saved_values.saveval[idx]), eachindex(saved_values.t)) + +saved_values = SavedValues(eltype(prob2D.tspan), eltype(prob2D.u0)) +cb = SavingCallback((t,u,integrator)->u[1], saved_values) +sol = solve(prob2D, Tsit5(), callback=cb) +@test all(idx -> sol.t[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> sol.u[idx][1] == saved_values.saveval[idx], eachindex(saved_values.t)) + +# saveat, scalar problem +saved_values = SavedValues(Float64, Float64) +saveat = linspace(prob.tspan..., 10) +cb = SavingCallback((t,u,integrator)->u, saved_values, saveat=saveat) +sol = solve(prob, Tsit5(), callback=cb) +@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> abs(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) + +# saveat, inplace problem +saved_values = SavedValues(eltype(prob2D.tspan), typeof(prob2D.u0)) +saveat = linspace(prob2D.tspan..., 10) +cb = SavingCallback((t,u,integrator)->copy(u), saved_values, saveat=saveat) +sol = solve(prob2D, Tsit5(), callback=cb) +@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> norm(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) + +saved_values = SavedValues(eltype(prob2D.tspan), eltype(prob2D.u0)) +saveat = linspace(prob2D.tspan..., 10) +cb = SavingCallback((t,u,integrator)->u[1], saved_values, saveat=saveat) +sol = solve(prob2D, Tsit5(), callback=cb) +@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> abs(sol(saveat[idx])[1] - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) + +# saveat, tdir<0, scalar problem +prob_inverse = ODEProblem(prob.f, prob.u0, (prob.tspan[end], prob.tspan[1])) +saved_values = SavedValues(Float64, Float64) +saveat = linspace(prob_inverse.tspan..., 10) +cb = SavingCallback((t,u,integrator)->u, saved_values, saveat=saveat, tdir=-1) +sol = solve(prob_inverse, Tsit5(), callback=cb) +@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> abs(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) + +# saveat, tdir<0, inplace problem +prob2D_inverse = ODEProblem(prob2D.f, prob2D.u0, (prob2D.tspan[end], prob2D.tspan[1])) +saved_values = SavedValues(eltype(prob2D_inverse.tspan), typeof(prob2D_inverse.u0)) +saveat = linspace(prob2D_inverse.tspan..., 10) +cb = SavingCallback((t,u,integrator)->copy(u), saved_values, saveat=saveat, tdir=-1) +sol = solve(prob2D_inverse, Tsit5(), callback=cb) +@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> norm(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) + +saved_values = SavedValues(eltype(prob2D_inverse.tspan), eltype(prob2D_inverse.u0)) +saveat = linspace(prob2D_inverse.tspan..., 10) +cb = SavingCallback((t,u,integrator)->u[1], saved_values, saveat=saveat, tdir=-1) +sol = solve(prob2D_inverse, Tsit5(), callback=cb) +@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) +@test all(idx -> abs(sol(saveat[idx])[1] - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))