Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from devmotion/domain
Domain callbacks
- Loading branch information
Showing
6 changed files
with
356 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
# Keep ODE solution in a domain specified by a function. Inspired by: | ||
# Shampine, L.F., S. Thompson, J.A. Kierzenka, and G.D. Byrne, "Non-negative solutions | ||
# of ODEs," Applied Mathematics and Computation Vol. 170, 2005, pp. 556-569. | ||
|
||
# type definitions | ||
|
||
abstract type AbstractDomainAffect{T,S,uType} end | ||
|
||
struct PositiveDomainAffect{T,S,uType} <: AbstractDomainAffect{T,S,uType} | ||
abstol::T | ||
scalefactor::S | ||
u::uType | ||
end | ||
|
||
struct GeneralDomainAffect{autonomous,F,T,S,uType} <: AbstractDomainAffect{T,S,uType} | ||
g::F | ||
abstol::T | ||
scalefactor::S | ||
u::uType | ||
resid::uType | ||
|
||
function GeneralDomainAffect{autonomous}(g::F, abstol::T, scalefactor::S, u::uType, | ||
resid::uType) where {autonomous,F,T,S,uType} | ||
new{autonomous,F,T,S,uType}(g, abstol, scalefactor, u, resid) | ||
end | ||
end | ||
|
||
# definitions of callback functions | ||
|
||
# Workaround since it is not possible to add methods to an abstract type: | ||
# https://github.com/JuliaLang/julia/issues/14919 | ||
(f::PositiveDomainAffect)(integrator) = affect!(integrator, f) | ||
(f::GeneralDomainAffect)(integrator) = affect!(integrator, f) | ||
|
||
# general method defintions for domain callbacks | ||
|
||
""" | ||
affect!(integrator, f::AbstractDomainAffect) | ||
Apply domain callback `f` to `integrator`. | ||
""" | ||
function affect!(integrator, f::AbstractDomainAffect{T,S,uType}) where {T,S,uType} | ||
# modify u | ||
u_modified!(integrator, modify_u!(integrator, f)) | ||
|
||
# define array of next time step, absolute tolerance, and scale factor | ||
u = uType <: Void ? similar(integrator.u) : f.u | ||
abstol = T <: Void ? integrator.opts.abstol : f.abstol | ||
scalefactor = S <: Void ? 1//2 : f.scalefactor | ||
|
||
# setup callback and save addtional arguments for checking next time step | ||
args = setup(f, integrator) | ||
|
||
# cache current time step | ||
dt = integrator.dt | ||
dt_modified = false | ||
|
||
# update time step of integrator to proposed next time step | ||
integrator.dt = get_proposed_dt(integrator) | ||
|
||
# adjust time step to bounds and time stops | ||
fix_dt_at_bounds!(integrator) | ||
modify_dt_for_tstops!(integrator) | ||
t = integrator.t + integrator.dt | ||
|
||
while integrator.tdir * integrator.dt > 0 | ||
# calculate estimated value of next step and its residuals | ||
integrator(u, t) | ||
|
||
# check whether time step is accepted | ||
isaccepted(t, u, abstol, f, args...) && break | ||
|
||
# reduce time step | ||
dtcache = integrator.dt | ||
integrator.dt *= scalefactor | ||
dt_modified = true | ||
|
||
# adjust new time step to bounds and time stops | ||
fix_dt_at_bounds!(integrator) | ||
modify_dt_for_tstops!(integrator) | ||
t = integrator.t + integrator.dt | ||
|
||
# abort iteration when time step is not changed | ||
if dtcache == integrator.dt | ||
if integrator.opts.verbose | ||
warn("Could not restrict values to domain. Iteration was canceled since ", | ||
"time step dt = ", integrator.dt, " could not be reduced.") | ||
end | ||
break | ||
end | ||
end | ||
|
||
# update current and next time step | ||
if dt_modified # add safety factor since guess is based on extrapolation | ||
set_proposed_dt!(integrator, 9//10*integrator.dt) | ||
else | ||
set_proposed_dt!(integrator, integrator.dt) | ||
end | ||
integrator.dt = dt | ||
end | ||
|
||
""" | ||
modify_u!(integrator, f::AbstractDomainAffect) | ||
Modify current state vector `u` of `integrator` if required, and return whether it actually | ||
was modified. | ||
""" | ||
modify_u!(integrator, ::AbstractDomainAffect) = false | ||
|
||
""" | ||
setup(f::AbstractDomainAffect, integrator) | ||
Setup callback `f` and return an arbitrary tuple whose elements are used as additional | ||
arguments in checking whether time step is accepted. | ||
""" | ||
setup(::AbstractDomainAffect, integrator) = () | ||
|
||
""" | ||
isaccepted(u, abstol, f::AbstractDomainAffect, args...) | ||
Return whether `u` is an acceptable state vector at the next time point given absolute | ||
tolerance `abstol`, callback `f`, and other optional arguments. | ||
""" | ||
isaccepted(t, u, tolerance, ::AbstractDomainAffect, args...) = true | ||
|
||
# specific method definitions for positive domain callback | ||
|
||
function modify_u!(integrator, f::PositiveDomainAffect) | ||
modified = false | ||
|
||
# set all negative values to zero | ||
@inbounds for i in eachindex(integrator.u) | ||
if integrator.u[i] < 0 | ||
integrator.u[i] = 0 | ||
modified = true | ||
end | ||
end | ||
|
||
modified | ||
end | ||
|
||
# state vector is accepted if its entries are greater than -abstol | ||
isaccepted(t, u, abstol::Number, ::PositiveDomainAffect) = all(x -> x + abstol > 0, u) | ||
isaccepted(t, u, abstol, ::PositiveDomainAffect) = all(x + y > 0 for (x,y) in zip(u, abstol)) | ||
|
||
# specific method definitions for general domain callback | ||
|
||
# create array of residuals | ||
setup(f::GeneralDomainAffect, integrator) = | ||
typeof(f.resid) <: Void ? (similar(integrator.u),) : (f.resid,) | ||
|
||
function isaccepted(t, u, abstol, f::GeneralDomainAffect{autonomous,F,T,S,uType}, | ||
resid) where {autonomous,F,T,S,uType} | ||
# calculate residuals | ||
if autonomous | ||
f.g(u, resid) | ||
else | ||
f.g(t, u, resid) | ||
end | ||
|
||
# accept time step if residuals are smaller than the tolerance | ||
if typeof(abstol) <: Number | ||
all(x-> x < abstol, resid) | ||
else | ||
# element-wise comparison | ||
all(x < y for (x,y) in zip(resid, abstol)) | ||
end | ||
end | ||
|
||
# callback definitions | ||
|
||
function GeneralDomain(g, u=nothing; nlsolve=NLSOLVEJL_SETUP(), save=true, | ||
abstol=nothing, scalefactor=nothing, autonomous=numargs(g)==2, | ||
nlopts=Dict(:ftol => 10*eps())) | ||
if typeof(u) <: Void | ||
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, nothing, nothing) | ||
else | ||
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, deepcopy(u), | ||
deepcopy(u)) | ||
end | ||
condition = (t,u,integrator) -> true | ||
CallbackSet(ManifoldProjection(g; nlsolve=nlsolve, save=false, | ||
autonomous=autonomous, nlopts=nlopts), | ||
DiscreteCallback(condition, affect!; save_positions=(false, save))) | ||
end | ||
|
||
function PositiveDomain(u=nothing; save=true, abstol=nothing, scalefactor=nothing) | ||
if typeof(u) <: Void | ||
affect! = PositiveDomainAffect(abstol, scalefactor, nothing) | ||
else | ||
affect! = PositiveDomainAffect(abstol, scalefactor, deepcopy(u)) | ||
end | ||
condition = (t,u,integrator) -> true | ||
DiscreteCallback(condition, affect!; save_positions=(false, save)) | ||
end | ||
|
||
export GeneralDomain, PositiveDomain |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
using DiffEqCallbacks, OrdinaryDiffEq, Base.Test | ||
|
||
# Non-negative ODE examples | ||
# | ||
# Reference: | ||
# Shampine, L.F., S. Thompson, J.A. Kierzenka, and G.D. Byrne, | ||
# "Non-negative solutions of ODEs," Applied Mathematics and Computation Vol. 170, 2005, | ||
# pp. 556-569. | ||
# https://www.mathworks.com/help/matlab/math/nonnegative-ode-solution.html | ||
|
||
""" | ||
Absolute value function | ||
```math | ||
\\frac{du}{dt} = -|u| | ||
``` | ||
with initial condition ``u₀=1``, and solution | ||
```math | ||
u(t) = u₀*e^{-t} | ||
``` | ||
for positive initial values ``u₀``. | ||
""" | ||
function absval(t,u,du) | ||
du[1] = -abs(u[1]) | ||
end | ||
(f::typeof(absval))(::Type{Val{:analytic}}, t, u₀) = u₀*exp(-t) | ||
prob_absval = ODEProblem(absval, [1.0], (0.0, 40.0)) | ||
|
||
# naive approach leads to large errors | ||
naive_sol_absval = solve(prob_absval, BS3()) | ||
@test naive_sol_absval.errors[:l∞] > 9e4 | ||
@test naive_sol_absval.errors[:l2] > 1.3e4 | ||
|
||
# general domain approach | ||
# can only guarantee approximately non-negative values | ||
function g(u,resid) | ||
resid[1] = u[1] < 0 ? -u[1] : 0 | ||
end | ||
|
||
general_sol_absval = solve(prob_absval, BS3(); callback=GeneralDomain(g, [1.0])) | ||
@test all(x -> x[1] ≥ -10*eps(), general_sol_absval.u) | ||
@test general_sol_absval.errors[:l∞] < 9.9e-5 | ||
@test general_sol_absval.errors[:l2] < 4.3e-5 | ||
@test general_sol_absval.errors[:final] < 4.4e-18 | ||
|
||
# test non-autonomous function | ||
g_t(t, u, resid) = g(u, resid) | ||
|
||
general_t_sol_absval = solve(prob_absval, BS3(); callback=GeneralDomain(g_t, [1.0])) | ||
@test general_sol_absval.t == general_t_sol_absval.t && | ||
general_sol_absval.u == general_t_sol_absval.u | ||
|
||
# positive domain approach | ||
# can guarantee non-negative values | ||
positive_sol_absval = solve(prob_absval, BS3(); callback=PositiveDomain([1.0])) | ||
@test all(x -> x[1] ≥ 0, positive_sol_absval.u) | ||
@test positive_sol_absval.errors[:l∞] < 9.9e-5 | ||
@test positive_sol_absval.errors[:l2] < 4.3e-5 | ||
@test positive_sol_absval.errors[:final] < 4.3e-18 # slightly better than general approach | ||
@test general_sol_absval.t == positive_sol_absval.t | ||
|
||
# specify abstol as array or scalar | ||
positive_sol_absval2 = solve(prob_absval, BS3(); callback=PositiveDomain([1.0], abstol=[1e-6])) | ||
@test positive_sol_absval.t == positive_sol_absval2.t && | ||
positive_sol_absval.u == positive_sol_absval2.u | ||
positive_sol_absval3 = solve(prob_absval, BS3(); callback=PositiveDomain([1.0], abstol=1e-6)) | ||
@test positive_sol_absval.t == positive_sol_absval3.t && | ||
positive_sol_absval.u == positive_sol_absval3.u | ||
|
||
# specify scalefactor | ||
positive_sol_absval4 = solve(prob_absval, BS3(); callback=PositiveDomain([1.0], scalefactor=0.2)) | ||
@test length(positive_sol_absval.t) < length(positive_sol_absval4.t) | ||
@test positive_sol_absval.errors[:l2] > positive_sol_absval4.errors[:l2] | ||
|
||
""" | ||
Knee problem | ||
```math | ||
\\frac{du}{dt} = \epsilon^{-1}(1-t-u)u | ||
``` | ||
with initial condition ``u0=1``, and generally ``0 < \epsilon << 1``. | ||
Here ``\epsilon=1e-6``. Then the solution approaches ``u=1-t`` for ``t<1`` | ||
and ``u=0`` for ``t>1``. | ||
""" | ||
function knee(t,u,du) | ||
du[1] = 1e6*(1-t-u[1])*u[1] | ||
end | ||
|
||
prob_knee = ODEProblem(knee, [1.0], (0.0, 2.0)) | ||
|
||
# unfortunately callbacks do not work with solver CVODE_BDF which is comparable to ode15s | ||
# used in MATLAB example, so we use Rodas5 | ||
naive_sol_knee = solve(prob_knee, Rodas5()) | ||
@test naive_sol_knee[1, end] ≈ -1.0 atol=1e-5 | ||
|
||
# positive domain approach | ||
positive_sol_knee = solve(prob_knee, Rodas5(); callback=PositiveDomain([1.0])) | ||
@test all(x -> x[1] ≥ 0, positive_sol_knee.u) | ||
@test positive_sol_knee[1, end] ≈ 0.0 |
Oops, something went wrong.