-
-
Notifications
You must be signed in to change notification settings - Fork 44
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #18 from ranocha/pull-request/5a0b3952
WIP: SavingCallback
- Loading branch information
Showing
7 changed files
with
200 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
julia 0.6.0 | ||
DiffEqBase 0.6.0 | ||
NLsolve | ||
ForwardDiff 0.5.0 | ||
ForwardDiff 0.5 0.6- | ||
DiffBase | ||
OrdinaryDiffEq | ||
RecursiveArrayTools | ||
DataStructures |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
""" | ||
SavedValues{tType<:Real, savevalType} | ||
A struct used to save values of the time in `t::Vector{tType}` and | ||
additional values in `saveval::Vector{savevalType}`. | ||
""" | ||
struct SavedValues{tType<:Real, savevalType} | ||
t::Vector{tType} | ||
saveval::Vector{savevalType} | ||
end | ||
|
||
""" | ||
SavedValues(tType::DataType, savevalType::DataType) | ||
Return `SavedValues{tType, savevalType}` with empty storage vectors. | ||
""" | ||
function SavedValues(tType::DataType, savevalType::DataType) | ||
SavedValues{tType, savevalType}(Vector{tType}(), Vector{savevalType}()) | ||
end | ||
|
||
function Base.show(io::IO, saved_values::SavedValues) | ||
tType = eltype(saved_values.t) | ||
savevalType = eltype(saved_values.saveval) | ||
print(io, "SavedValues{tType=", tType, ", savevalType=", savevalType, "}", | ||
"\nt:\n", saved_values.t, "\nsaveval:\n", saved_values.saveval) | ||
end | ||
|
||
|
||
mutable struct SavingAffect{SaveFunc, tType, savevalType, saveatType} | ||
save_func::SaveFunc | ||
saved_values::SavedValues{tType, savevalType} | ||
saveat::saveatType | ||
save_everystep::Bool | ||
saveiter::Int | ||
end | ||
|
||
function (affect!::SavingAffect)(integrator) | ||
# see OrdinaryDiffEq.jl -> integrator_utils.jl, function savevalues! | ||
while !isempty(affect!.saveat) && integrator.tdir*top(affect!.saveat) <= integrator.tdir*integrator.t # Perform saveat | ||
affect!.saveiter += 1 | ||
curt = pop!(affect!.saveat) # current time | ||
if curt != integrator.t # If <t, interpolate | ||
ode_addsteps!(integrator) | ||
Θ = (curt - integrator.tprev)/integrator.dt | ||
curu = ode_interpolant(Θ, integrator, nothing, Val{0}) # out of place, but no force copy later | ||
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, curt) | ||
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(curt, curu, integrator)) | ||
else # ==t, just save | ||
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t) | ||
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(integrator.t, integrator.u, integrator)) | ||
end | ||
end | ||
if affect!.save_everystep | ||
affect!.saveiter += 1 | ||
copyat_or_push!(affect!.saved_values.t, affect!.saveiter, integrator.t) | ||
copyat_or_push!(affect!.saved_values.saveval, affect!.saveiter, affect!.save_func(integrator.t, integrator.u, integrator)) | ||
end | ||
u_modified!(integrator, false) | ||
end | ||
|
||
function saving_initialize(cb, t, u, integrator) | ||
cb.affect!(integrator) | ||
end | ||
|
||
|
||
""" | ||
SavingCallback(save_func, saved_values::SavedValues; | ||
saveat=Vector{eltype(saved_values.t)}(), | ||
save_everystep=isempty(saveat), | ||
tdir=1) | ||
A `DiscreteCallback` applied after every step, saving the time `t` and the value | ||
of `save_func(t, u, integrator)` in `saved_values`. | ||
If `save_everystep`, every step of the integrator is saved. | ||
If `saveat` is specified, the values are saved at the given times, using | ||
interpolation if necessary. | ||
If the time `tdir` direction is not positive, i.e. `tspan[1] > tspan[2]`, | ||
`tdir = -1` has to be specified. | ||
""" | ||
function SavingCallback(save_func, saved_values::SavedValues; | ||
saveat=Vector{eltype(saved_values.t)}(), | ||
save_everystep=isempty(saveat), | ||
tdir=1) | ||
# saveat conversions, see OrdinaryDiffEq.jl -> integrators/type.jl | ||
saveat_vec = collect(saveat) | ||
if tdir > 0 | ||
saveat_internal = binary_minheap(saveat_vec) | ||
else | ||
saveat_internal = binary_maxheap(saveat_vec) | ||
end | ||
affect! = SavingAffect(save_func, saved_values, saveat_internal, save_everystep, 0) | ||
condtion = (t, u, integrator) -> true | ||
DiscreteCallback(condtion, affect!; | ||
initialize = saving_initialize, | ||
save_positions=(false,false)) | ||
end | ||
|
||
|
||
export SavingCallback, SavedValues |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
using Base.Test, OrdinaryDiffEq, DiffEqProblemLibrary, DiffEqCallbacks | ||
|
||
# save_everystep, scalar problem | ||
prob = prob_ode_linear | ||
saved_values = SavedValues(Float64, Float64) | ||
cb = SavingCallback((t,u,integrator)->u, saved_values) | ||
sol = solve(prob, Tsit5(), callback=cb) | ||
print("\n", saved_values, "\n") | ||
@test all(idx -> sol.t[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> sol.u[idx] == saved_values.saveval[idx], eachindex(saved_values.t)) | ||
|
||
# save_everystep, inplace problem | ||
prob2D = prob_ode_2Dlinear | ||
saved_values = SavedValues(eltype(prob2D.tspan), typeof(prob2D.u0)) | ||
cb = SavingCallback((t,u,integrator)->copy(u), saved_values) | ||
sol = solve(prob2D, Tsit5(), callback=cb) | ||
@test all(idx -> sol.t[idx] .== saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> all(sol.u[idx] .== saved_values.saveval[idx]), eachindex(saved_values.t)) | ||
|
||
saved_values = SavedValues(eltype(prob2D.tspan), eltype(prob2D.u0)) | ||
cb = SavingCallback((t,u,integrator)->u[1], saved_values) | ||
sol = solve(prob2D, Tsit5(), callback=cb) | ||
@test all(idx -> sol.t[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> sol.u[idx][1] == saved_values.saveval[idx], eachindex(saved_values.t)) | ||
|
||
# saveat, scalar problem | ||
saved_values = SavedValues(Float64, Float64) | ||
saveat = linspace(prob.tspan..., 10) | ||
cb = SavingCallback((t,u,integrator)->u, saved_values, saveat=saveat) | ||
sol = solve(prob, Tsit5(), callback=cb) | ||
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> abs(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) | ||
|
||
# saveat, inplace problem | ||
saved_values = SavedValues(eltype(prob2D.tspan), typeof(prob2D.u0)) | ||
saveat = linspace(prob2D.tspan..., 10) | ||
cb = SavingCallback((t,u,integrator)->copy(u), saved_values, saveat=saveat) | ||
sol = solve(prob2D, Tsit5(), callback=cb) | ||
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> norm(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) | ||
|
||
saved_values = SavedValues(eltype(prob2D.tspan), eltype(prob2D.u0)) | ||
saveat = linspace(prob2D.tspan..., 10) | ||
cb = SavingCallback((t,u,integrator)->u[1], saved_values, saveat=saveat) | ||
sol = solve(prob2D, Tsit5(), callback=cb) | ||
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> abs(sol(saveat[idx])[1] - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) | ||
|
||
# saveat, tdir<0, scalar problem | ||
prob_inverse = ODEProblem(prob.f, prob.u0, (prob.tspan[end], prob.tspan[1])) | ||
saved_values = SavedValues(Float64, Float64) | ||
saveat = linspace(prob_inverse.tspan..., 10) | ||
cb = SavingCallback((t,u,integrator)->u, saved_values, saveat=saveat, tdir=-1) | ||
sol = solve(prob_inverse, Tsit5(), callback=cb) | ||
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> abs(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) | ||
|
||
# saveat, tdir<0, inplace problem | ||
prob2D_inverse = ODEProblem(prob2D.f, prob2D.u0, (prob2D.tspan[end], prob2D.tspan[1])) | ||
saved_values = SavedValues(eltype(prob2D_inverse.tspan), typeof(prob2D_inverse.u0)) | ||
saveat = linspace(prob2D_inverse.tspan..., 10) | ||
cb = SavingCallback((t,u,integrator)->copy(u), saved_values, saveat=saveat, tdir=-1) | ||
sol = solve(prob2D_inverse, Tsit5(), callback=cb) | ||
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> norm(sol(saveat[idx]) - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) | ||
|
||
saved_values = SavedValues(eltype(prob2D_inverse.tspan), eltype(prob2D_inverse.u0)) | ||
saveat = linspace(prob2D_inverse.tspan..., 10) | ||
cb = SavingCallback((t,u,integrator)->u[1], saved_values, saveat=saveat, tdir=-1) | ||
sol = solve(prob2D_inverse, Tsit5(), callback=cb) | ||
@test all(idx -> saveat[idx] == saved_values.t[idx], eachindex(saved_values.t)) | ||
@test all(idx -> abs(sol(saveat[idx])[1] - saved_values.saveval[idx]) < 5.e-15, eachindex(saved_values.t)) |