Skip to content

Commit

Permalink
Merge pull request #18 from ranocha/pull-request/5a0b3952
Browse files Browse the repository at this point in the history
WIP: SavingCallback
  • Loading branch information
ChrisRackauckas committed Sep 30, 2017
2 parents 6316547 + 38326de commit f962c26
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 6 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,23 @@ StepsizeLimiter(dtFE;safety_factor=9//10,max_step=false,cached_dtcache=0.0)
which defaults to `9//10`. `max_step=true` makes every step equal to
`safety_factor*dtFE(t,u)` when the solver is set to `adaptive=false`. `cached_dtcache`
should be set to match the type for time when not using Float64 values.

## SavingCallback

The aving callback lets you define a function `save_func(t, u, integrator)` which
returns quantities of interest that shall be saved. The constructor is:

```julia
SavingCallback(save_func, saved_values::SavedValues;
saveat=Vector{eltype(saved_values.t)}(),
save_everystep=isempty(saveat),
tdir=1)
```
- `save_func(t, u, integrator)` returns the quantities which shall be saved.
- `saved_values::SavedValues` contains vectors `t::Vector{tType}`,
`saveval::Vector{savevalType}` of the saved quantities. Here,
`save_func(t, u, integrator)::savevalType`.
- `saveat` Mimicks `saveat` in `solve` for ODEs.
- `save_everystep` Mimicks `save_everystep` in `solve` for ODEs.
- `tdir` should be `sign(tspan[end]-tspan[1])`. It defaults to `1` and should
be adapted if `tspan[1] > tspan[end]`.
4 changes: 3 additions & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
julia 0.6.0
DiffEqBase 0.6.0
NLsolve
ForwardDiff 0.5.0
ForwardDiff 0.5 0.6-
DiffBase
OrdinaryDiffEq
RecursiveArrayTools
DataStructures
6 changes: 3 additions & 3 deletions src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@ __precompile__()

module DiffEqCallbacks

using DiffEqBase, NLsolve, ForwardDiff
import DiffBase
using DiffEqBase, NLsolve, ForwardDiff, RecursiveArrayTools, DataStructures

import OrdinaryDiffEq: fix_dt_at_bounds!, modify_dt_for_tstops!
import OrdinaryDiffEq: fix_dt_at_bounds!, modify_dt_for_tstops!, ode_addsteps!, ode_interpolant

include("autoabstol.jl")
include("manifold.jl")
include("domain.jl")
include("stepsizelimiters.jl")
include("saving.jl")

end # module
3 changes: 1 addition & 2 deletions src/manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@ Base.@pure function determine_chunksize(u,CS)
end

function autodiff_setup(f!, initial_x, chunk_size::Type{Val{CS}}) where CS

permf! = (fx, x) -> f!(reshape(x,size(initial_x)...), fx)

fx2 = copy(initial_x)
jac_cfg = ForwardDiff.JacobianConfig(nothing, initial_x, initial_x, ForwardDiff.Chunk{CS}())
jac_cfg = ForwardDiff.JacobianConfig(permf!, initial_x, initial_x, ForwardDiff.Chunk{CS}())
g! = (x, gx) -> ForwardDiff.jacobian!(gx, permf!, fx2, x, jac_cfg)

fg! = (x, fx, gx) -> begin
Expand Down
100 changes: 100 additions & 0 deletions src/saving.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
SavedValues{tType<:Real, savevalType}
A struct used to save values of the time in `t::Vector{tType}` and
additional values in `saveval::Vector{savevalType}`.
"""
struct SavedValues{tType<:Real, savevalType}
t::Vector{tType}
saveval::Vector{savevalType}
end

"""
SavedValues(tType::DataType, savevalType::DataType)
Return `SavedValues{tType, savevalType}` with empty storage vectors.
"""
function SavedValues(tType::DataType, savevalType::DataType)
SavedValues{tType, savevalType}(Vector{tType}(), Vector{savevalType}())
end

function Base.show(io::IO, saved_values::SavedValues)
tType = eltype(saved_values.t)
savevalType = eltype(saved_values.saveval)
print(io, "SavedValues{tType=", tType, ", savevalType=", savevalType, "}",
"\nt:\n", saved_values.t, "\nsaveval:\n", saved_values.saveval)
end


mutable struct SavingAffect{SaveFunc, tType, savevalType, saveatType}
save_func::SaveFunc
saved_values::SavedValues{tType, savevalType}
saveat::saveatType
save_everystep::Bool
saveiter::Int
end

function (affect!::SavingAffect)(integrator)
# see OrdinaryDiffEq.jl -> integrator_utils.jl, function savevalues!
while !isempty(affect!.saveat) && integrator.tdir*top(affect!.saveat) <= integrator.tdir*integrator.t # Perform saveat
affect!.saveiter += 1
curt = pop!(affect!.saveat) # current time
if curt != integrator.t # If <t, interpolate
ode_addsteps!(integrator)
Θ = (curt - integrator.tprev)/integrator.dt
curu = ode_interpolant(Θ, integrator, nothing, Val{0}) # out of place, but no force copy later
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, curt)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(curt, curu, integrator))
else # ==t, just save
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(integrator.t, integrator.u, integrator))
end
end
if affect!.save_everystep
affect!.saveiter += 1
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t)
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(integrator.t, integrator.u, integrator))
end
u_modified!(integrator, false)
end

function saving_initialize(cb, t, u, integrator)
cb.affect!(integrator)
end


"""
SavingCallback(save_func, saved_values::SavedValues;
saveat=Vector{eltype(saved_values.t)}(),
save_everystep=isempty(saveat),
tdir=1)
A `DiscreteCallback` applied after every step, saving the time `t` and the value
of `save_func(t, u, integrator)` in `saved_values`.
If `save_everystep`, every step of the integrator is saved.
If `saveat` is specified, the values are saved at the given times, using
interpolation if necessary.
If the time `tdir` direction is not positive, i.e. `tspan[1] > tspan[2]`,
`tdir = -1` has to be specified.
"""
function SavingCallback(save_func, saved_values::SavedValues;
saveat=Vector{eltype(saved_values.t)}(),
save_everystep=isempty(saveat),
tdir=1)
# saveat conversions, see OrdinaryDiffEq.jl -> integrators/type.jl
saveat_vec = collect(saveat)
if tdir > 0
saveat_internal = binary_minheap(saveat_vec)
else
saveat_internal = binary_maxheap(saveat_vec)
end
affect! = SavingAffect(save_func, saved_values, saveat_internal, save_everystep, 0)
condtion = (t, u, integrator) -> true
DiscreteCallback(condtion, affect!;
initialize = saving_initialize,
save_positions=(false,false))
end


export SavingCallback, SavedValues
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ tic()
@time @testset "Domain tests" begin include("domain_tests.jl") end
@time @testset "Manifold tests" begin include("manifold_tests.jl") end
@time @testset "StepsizeLimiter tests" begin include("stepsizelimiter_tests.jl") end
@time @testset "Saving tests" begin include("saving_tests.jl") end
toc()
72 changes: 72 additions & 0 deletions test/saving_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using Base.Test, OrdinaryDiffEq, DiffEqProblemLibrary, DiffEqCallbacks

# save_everystep, scalar problem
prob = prob_ode_linear
saved_values = SavedValues(Float64, Float64)
cb = SavingCallback((t,u,integrator)->u, saved_values)
sol = solve(prob, Tsit5(), callback=cb)
print("\n", saved_values, "\n")
@test all(idx -> sol.t[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> sol.u[idx] == saved_values.saveval[idx], eachindex(saved_values.t))

# save_everystep, inplace problem
prob2D = prob_ode_2Dlinear
saved_values = SavedValues(eltype(prob2D.tspan), typeof(prob2D.u0))
cb = SavingCallback((t,u,integrator)->copy(u), saved_values)
sol = solve(prob2D, Tsit5(), callback=cb)
@test all(idx -> sol.t[idx] .== saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> all(sol.u[idx] .== saved_values.saveval[idx]), eachindex(saved_values.t))

saved_values = SavedValues(eltype(prob2D.tspan), eltype(prob2D.u0))
cb = SavingCallback((t,u,integrator)->u[1], saved_values)
sol = solve(prob2D, Tsit5(), callback=cb)
@test all(idx -> sol.t[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> sol.u[idx][1] == saved_values.saveval[idx], eachindex(saved_values.t))

# saveat, scalar problem
saved_values = SavedValues(Float64, Float64)
saveat = linspace(prob.tspan..., 10)
cb = SavingCallback((t,u,integrator)->u, saved_values, saveat=saveat)
sol = solve(prob, Tsit5(), callback=cb)
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> abs(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))

# saveat, inplace problem
saved_values = SavedValues(eltype(prob2D.tspan), typeof(prob2D.u0))
saveat = linspace(prob2D.tspan..., 10)
cb = SavingCallback((t,u,integrator)->copy(u), saved_values, saveat=saveat)
sol = solve(prob2D, Tsit5(), callback=cb)
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> norm(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))

saved_values = SavedValues(eltype(prob2D.tspan), eltype(prob2D.u0))
saveat = linspace(prob2D.tspan..., 10)
cb = SavingCallback((t,u,integrator)->u[1], saved_values, saveat=saveat)
sol = solve(prob2D, Tsit5(), callback=cb)
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> abs(sol(saveat[idx])[1] - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))

# saveat, tdir<0, scalar problem
prob_inverse = ODEProblem(prob.f, prob.u0, (prob.tspan[end], prob.tspan[1]))
saved_values = SavedValues(Float64, Float64)
saveat = linspace(prob_inverse.tspan..., 10)
cb = SavingCallback((t,u,integrator)->u, saved_values, saveat=saveat, tdir=-1)
sol = solve(prob_inverse, Tsit5(), callback=cb)
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> abs(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))

# saveat, tdir<0, inplace problem
prob2D_inverse = ODEProblem(prob2D.f, prob2D.u0, (prob2D.tspan[end], prob2D.tspan[1]))
saved_values = SavedValues(eltype(prob2D_inverse.tspan), typeof(prob2D_inverse.u0))
saveat = linspace(prob2D_inverse.tspan..., 10)
cb = SavingCallback((t,u,integrator)->copy(u), saved_values, saveat=saveat, tdir=-1)
sol = solve(prob2D_inverse, Tsit5(), callback=cb)
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> norm(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))

saved_values = SavedValues(eltype(prob2D_inverse.tspan), eltype(prob2D_inverse.u0))
saveat = linspace(prob2D_inverse.tspan..., 10)
cb = SavingCallback((t,u,integrator)->u[1], saved_values, saveat=saveat, tdir=-1)
sol = solve(prob2D_inverse, Tsit5(), callback=cb)
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t))
@test all(idx -> abs(sol(saveat[idx])[1] - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t))

0 comments on commit f962c26

Please sign in to comment.