Skip to content

Commit

Permalink
Merge pull request #14 from devmotion/domain
Browse files Browse the repository at this point in the history
Domain callbacks
  • Loading branch information
ChrisRackauckas committed Aug 18, 2017
2 parents 9b8e579 + a40b08a commit 09ae398
Show file tree
Hide file tree
Showing 6 changed files with 356 additions and 15 deletions.
3 changes: 3 additions & 0 deletions src/DiffEqCallbacks.jl
Expand Up @@ -5,7 +5,10 @@ module DiffEqCallbacks
using DiffEqBase, NLsolve, ForwardDiff
import DiffBase

import OrdinaryDiffEq: fix_dt_at_bounds!, modify_dt_for_tstops!

include("autoabstol.jl")
include("manifold.jl")
include("domain.jl")

end # module
197 changes: 197 additions & 0 deletions src/domain.jl
@@ -0,0 +1,197 @@
# Keep ODE solution in a domain specified by a function. Inspired by:
# Shampine, L.F., S. Thompson, J.A. Kierzenka, and G.D. Byrne, "Non-negative solutions
# of ODEs," Applied Mathematics and Computation Vol. 170, 2005, pp. 556-569.

# type definitions

abstract type AbstractDomainAffect{T,S,uType} end

struct PositiveDomainAffect{T,S,uType} <: AbstractDomainAffect{T,S,uType}
abstol::T
scalefactor::S
u::uType
end

struct GeneralDomainAffect{autonomous,F,T,S,uType} <: AbstractDomainAffect{T,S,uType}
g::F
abstol::T
scalefactor::S
u::uType
resid::uType

function GeneralDomainAffect{autonomous}(g::F, abstol::T, scalefactor::S, u::uType,
resid::uType) where {autonomous,F,T,S,uType}
new{autonomous,F,T,S,uType}(g, abstol, scalefactor, u, resid)
end
end

# definitions of callback functions

# Workaround since it is not possible to add methods to an abstract type:
# https://github.com/JuliaLang/julia/issues/14919
(f::PositiveDomainAffect)(integrator) = affect!(integrator, f)
(f::GeneralDomainAffect)(integrator) = affect!(integrator, f)

# general method defintions for domain callbacks

"""
affect!(integrator, f::AbstractDomainAffect)
Apply domain callback `f` to `integrator`.
"""
function affect!(integrator, f::AbstractDomainAffect{T,S,uType}) where {T,S,uType}
# modify u
u_modified!(integrator, modify_u!(integrator, f))

# define array of next time step, absolute tolerance, and scale factor
u = uType <: Void ? similar(integrator.u) : f.u
abstol = T <: Void ? integrator.opts.abstol : f.abstol
scalefactor = S <: Void ? 1//2 : f.scalefactor

# setup callback and save addtional arguments for checking next time step
args = setup(f, integrator)

# cache current time step
dt = integrator.dt
dt_modified = false

# update time step of integrator to proposed next time step
integrator.dt = get_proposed_dt(integrator)

# adjust time step to bounds and time stops
fix_dt_at_bounds!(integrator)
modify_dt_for_tstops!(integrator)
t = integrator.t + integrator.dt

while integrator.tdir * integrator.dt > 0
# calculate estimated value of next step and its residuals
integrator(u, t)

# check whether time step is accepted
isaccepted(t, u, abstol, f, args...) && break

# reduce time step
dtcache = integrator.dt
integrator.dt *= scalefactor
dt_modified = true

# adjust new time step to bounds and time stops
fix_dt_at_bounds!(integrator)
modify_dt_for_tstops!(integrator)
t = integrator.t + integrator.dt

# abort iteration when time step is not changed
if dtcache == integrator.dt
if integrator.opts.verbose
warn("Could not restrict values to domain. Iteration was canceled since ",
"time step dt = ", integrator.dt, " could not be reduced.")
end
break
end
end

# update current and next time step
if dt_modified # add safety factor since guess is based on extrapolation
set_proposed_dt!(integrator, 9//10*integrator.dt)
else
set_proposed_dt!(integrator, integrator.dt)
end
integrator.dt = dt
end

"""
modify_u!(integrator, f::AbstractDomainAffect)
Modify current state vector `u` of `integrator` if required, and return whether it actually
was modified.
"""
modify_u!(integrator, ::AbstractDomainAffect) = false

"""
setup(f::AbstractDomainAffect, integrator)
Setup callback `f` and return an arbitrary tuple whose elements are used as additional
arguments in checking whether time step is accepted.
"""
setup(::AbstractDomainAffect, integrator) = ()

"""
isaccepted(u, abstol, f::AbstractDomainAffect, args...)
Return whether `u` is an acceptable state vector at the next time point given absolute
tolerance `abstol`, callback `f`, and other optional arguments.
"""
isaccepted(t, u, tolerance, ::AbstractDomainAffect, args...) = true

# specific method definitions for positive domain callback

function modify_u!(integrator, f::PositiveDomainAffect)
modified = false

# set all negative values to zero
@inbounds for i in eachindex(integrator.u)
if integrator.u[i] < 0
integrator.u[i] = 0
modified = true
end
end

modified
end

# state vector is accepted if its entries are greater than -abstol
isaccepted(t, u, abstol::Number, ::PositiveDomainAffect) = all(x -> x + abstol > 0, u)
isaccepted(t, u, abstol, ::PositiveDomainAffect) = all(x + y > 0 for (x,y) in zip(u, abstol))

# specific method definitions for general domain callback

# create array of residuals
setup(f::GeneralDomainAffect, integrator) =
typeof(f.resid) <: Void ? (similar(integrator.u),) : (f.resid,)

function isaccepted(t, u, abstol, f::GeneralDomainAffect{autonomous,F,T,S,uType},
resid) where {autonomous,F,T,S,uType}
# calculate residuals
if autonomous
f.g(u, resid)
else
f.g(t, u, resid)
end

# accept time step if residuals are smaller than the tolerance
if typeof(abstol) <: Number
all(x-> x < abstol, resid)
else
# element-wise comparison
all(x < y for (x,y) in zip(resid, abstol))
end
end

# callback definitions

function GeneralDomain(g, u=nothing; nlsolve=NLSOLVEJL_SETUP(), save=true,
abstol=nothing, scalefactor=nothing, autonomous=numargs(g)==2,
nlopts=Dict(:ftol => 10*eps()))
if typeof(u) <: Void
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, nothing, nothing)
else
affect! = GeneralDomainAffect{autonomous}(g, abstol, scalefactor, deepcopy(u),
deepcopy(u))
end
condition = (t,u,integrator) -> true
CallbackSet(ManifoldProjection(g; nlsolve=nlsolve, save=false,
autonomous=autonomous, nlopts=nlopts),
DiscreteCallback(condition, affect!; save_positions=(false, save)))
end

function PositiveDomain(u=nothing; save=true, abstol=nothing, scalefactor=nothing)
if typeof(u) <: Void
affect! = PositiveDomainAffect(abstol, scalefactor, nothing)
else
affect! = PositiveDomainAffect(abstol, scalefactor, deepcopy(u))
end
condition = (t,u,integrator) -> true
DiscreteCallback(condition, affect!; save_positions=(false, save))
end

export GeneralDomain, PositiveDomain
52 changes: 39 additions & 13 deletions src/manifold.jl
Expand Up @@ -24,16 +24,18 @@ function autodiff_setup{CS}(f!, initial_x,chunk_size::Type{Val{CS}})
DiffBase.value(jac_res)
end

return DifferentiableMultivariateFunction(f!, g!, fg!)
return DifferentiableMultivariateFunction((x,resid)->f!(reshape(x,size(initial_x)...),
resid),
g!, fg!)
end

function non_autodiff_setup(f!, initial_x)
DifferentiableMultivariateFunction((resid,x)->f!(resid,reshape(x,size(initial_x)...)))
DifferentiableMultivariateFunction((x,resid)->f!(reshape(x,size(initial_x)...), resid))
end

immutable NLSOLVEJL_SETUP{CS,AD} end
Base.@pure NLSOLVEJL_SETUP(;chunk_size=0,autodiff=true) = NLSOLVEJL_SETUP{chunk_size,autodiff}()
(p::NLSOLVEJL_SETUP)(f,u0) = (res=NLsolve.nlsolve(f,u0); res.zero)
(p::NLSOLVEJL_SETUP)(f, u0; kwargs...) = (res=NLsolve.nlsolve(f, u0; kwargs...); res.zero)
function (p::NLSOLVEJL_SETUP{CS,AD}){CS,AD}(::Type{Val{:init}},f,u0_prototype)
if AD
return non_autodiff_setup(f,u0_prototype)
Expand All @@ -47,28 +49,52 @@ get_chunksize{CS,AD}(x::NLSOLVEJL_SETUP{CS,AD}) = CS

#########################

type ManifoldProjection{NL}
# wrapper for non-autonomous functions
mutable struct NonAutonomousFunction{F}
f::F
t
end
(p::NonAutonomousFunction)(u, res) = p.f(p.t, u, res)

mutable struct ManifoldProjection{autonomous,F,NL,O}
g::F
nl_rhs
nlsolve::NL
nlopts::O

function ManifoldProjection{autonomous}(g, nlsolve, nlopts) where {autonomous}
# replace residual function if it is time-dependent
# since NLsolve only accepts functions with two arguments
if !autonomous
g = NonAutonomousFunction(g, 0)
end

new{autonomous,typeof(g),typeof(nlsolve),typeof(nlopts)}(g, g, nlsolve, nlopts)
end
end

# Now make `affect!` for this:
function (p::ManifoldProjection)(integrator)
nlres = reshape(p.nlsolve(p.nl_rhs,vec(integrator.u)),size(integrator.u)...)::typeof(integrator.u)
function (p::ManifoldProjection{autonomous,NL})(integrator) where {autonomous,NL}
# update current time if residual function is time-dependent
if !autonomous
p.g.t = integrator.t
end

nlres = reshape(p.nlsolve(p.nl_rhs, vec(integrator.u); p.nlopts...),
size(integrator.u)...)::typeof(integrator.u)
integrator.u .= nlres
end

function Manifold_initialize(cb,t,u,integrator)
cb.affect!.nl_rhs = cb.affect!.nlsolve(
Val{:init},
cb.affect!.nl_rhs,
integrator.u)
cb.affect!.nl_rhs = cb.affect!.nlsolve(Val{:init}, cb.affect!.g, u)
end

function ManifoldProjection(g;nlsolve=NLSOLVEJL_SETUP(),save=true)
affect! = ManifoldProjection(g,nlsolve)
function ManifoldProjection(g; nlsolve=NLSOLVEJL_SETUP(), save=true,
autonomous=numargs(g)==2, nlopts=Dict{Symbol,Any}())
affect! = ManifoldProjection{autonomous}(g, nlsolve, nlopts)
condtion = (t,u,integrator) -> true
save_positions = (false,save)
DiscreteCallback(condtion,affect!;
DiscreteCallback(condtion, affect!;
initialize = Manifold_initialize,
save_positions=save_positions)
end
Expand Down
101 changes: 101 additions & 0 deletions test/domain_tests.jl
@@ -0,0 +1,101 @@
using DiffEqCallbacks, OrdinaryDiffEq, Base.Test

# Non-negative ODE examples
#
# Reference:
# Shampine, L.F., S. Thompson, J.A. Kierzenka, and G.D. Byrne,
# "Non-negative solutions of ODEs," Applied Mathematics and Computation Vol. 170, 2005,
# pp. 556-569.
# https://www.mathworks.com/help/matlab/math/nonnegative-ode-solution.html

"""
Absolute value function
```math
\\frac{du}{dt} = -|u|
```
with initial condition ``u₀=1``, and solution
```math
u(t) = u₀*e^{-t}
```
for positive initial values ``u₀``.
"""
function absval(t,u,du)
du[1] = -abs(u[1])
end
(f::typeof(absval))(::Type{Val{:analytic}}, t, u₀) = u₀*exp(-t)
prob_absval = ODEProblem(absval, [1.0], (0.0, 40.0))

# naive approach leads to large errors
naive_sol_absval = solve(prob_absval, BS3())
@test naive_sol_absval.errors[:l∞] > 9e4
@test naive_sol_absval.errors[:l2] > 1.3e4

# general domain approach
# can only guarantee approximately non-negative values
function g(u,resid)
resid[1] = u[1] < 0 ? -u[1] : 0
end

general_sol_absval = solve(prob_absval, BS3(); callback=GeneralDomain(g, [1.0]))
@test all(x -> x[1] -10*eps(), general_sol_absval.u)
@test general_sol_absval.errors[:l∞] < 9.9e-5
@test general_sol_absval.errors[:l2] < 4.3e-5
@test general_sol_absval.errors[:final] < 4.4e-18

# test non-autonomous function
g_t(t, u, resid) = g(u, resid)

general_t_sol_absval = solve(prob_absval, BS3(); callback=GeneralDomain(g_t, [1.0]))
@test general_sol_absval.t == general_t_sol_absval.t &&
general_sol_absval.u == general_t_sol_absval.u

# positive domain approach
# can guarantee non-negative values
positive_sol_absval = solve(prob_absval, BS3(); callback=PositiveDomain([1.0]))
@test all(x -> x[1] 0, positive_sol_absval.u)
@test positive_sol_absval.errors[:l∞] < 9.9e-5
@test positive_sol_absval.errors[:l2] < 4.3e-5
@test positive_sol_absval.errors[:final] < 4.3e-18 # slightly better than general approach
@test general_sol_absval.t == positive_sol_absval.t

# specify abstol as array or scalar
positive_sol_absval2 = solve(prob_absval, BS3(); callback=PositiveDomain([1.0], abstol=[1e-6]))
@test positive_sol_absval.t == positive_sol_absval2.t &&
positive_sol_absval.u == positive_sol_absval2.u
positive_sol_absval3 = solve(prob_absval, BS3(); callback=PositiveDomain([1.0], abstol=1e-6))
@test positive_sol_absval.t == positive_sol_absval3.t &&
positive_sol_absval.u == positive_sol_absval3.u

# specify scalefactor
positive_sol_absval4 = solve(prob_absval, BS3(); callback=PositiveDomain([1.0], scalefactor=0.2))
@test length(positive_sol_absval.t) < length(positive_sol_absval4.t)
@test positive_sol_absval.errors[:l2] > positive_sol_absval4.errors[:l2]

"""
Knee problem
```math
\\frac{du}{dt} = \epsilon^{-1}(1-t-u)u
```
with initial condition ``u0=1``, and generally ``0 < \epsilon << 1``.
Here ``\epsilon=1e-6``. Then the solution approaches ``u=1-t`` for ``t<1``
and ``u=0`` for ``t>1``.
"""
function knee(t,u,du)
du[1] = 1e6*(1-t-u[1])*u[1]
end

prob_knee = ODEProblem(knee, [1.0], (0.0, 2.0))

# unfortunately callbacks do not work with solver CVODE_BDF which is comparable to ode15s
# used in MATLAB example, so we use Rodas5
naive_sol_knee = solve(prob_knee, Rodas5())
@test naive_sol_knee[1, end] -1.0 atol=1e-5

# positive domain approach
positive_sol_knee = solve(prob_knee, Rodas5(); callback=PositiveDomain([1.0]))
@test all(x -> x[1] 0, positive_sol_knee.u)
@test positive_sol_knee[1, end] 0.0

0 comments on commit 09ae398

Please sign in to comment.