Skip to content

Commit

Permalink
update to new callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 17, 2018
1 parent cfaae7e commit 0fc57e9
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 20 deletions.
38 changes: 26 additions & 12 deletions src/common_interface/callbacks.jl
Expand Up @@ -2,19 +2,19 @@

# Base Case: Only one callback
function find_first_continuous_callback(integrator, callback::DiffEqBase.AbstractContinuousCallback)
(find_callback_time(integrator,callback)...,1,1)
(find_callback_time(integrator,callback,1)...,1,1)
end

# Starting Case: Compute on the first callback
function find_first_continuous_callback(integrator, callback::DiffEqBase.AbstractContinuousCallback, args...)
find_first_continuous_callback(integrator,find_callback_time(integrator,callback)...,1,1,args...)
find_first_continuous_callback(integrator,find_callback_time(integrator,callback,1)...,1,1,args...)
end

function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Float64,
event_occured::Bool,idx::Int,counter::Int,
callback2)
counter += 1 # counter is idx for callback2.
tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2)
tmin2,upcrossing2,event_occurred2 = find_callback_time(integrator,callback2,counter)

if event_occurred2 && (tmin2 < tmin || !event_occured)
return tmin2,upcrossing2,true,counter,counter
Expand All @@ -27,7 +27,7 @@ function find_first_continuous_callback(integrator,tmin::Number,upcrossing::Floa
find_first_continuous_callback(integrator,find_first_continuous_callback(integrator,tmin,upcrossing,event_occured,idx,counter,callback2)...,args...)
end

@inline function determine_event_occurance(integrator,callback)
@inline function determine_event_occurance(integrator,callback,counter)
event_occurred = false
Θs = range(typeof(integrator.t)(0), stop=typeof(integrator.t)(1), length=callback.interp_points)
interp_index = 0
Expand All @@ -41,7 +41,7 @@ end
previous_condition = callback.condition(@view(integrator.uprev[callback.idxs]),integrator.tprev,integrator)
end

if integrator.event_last_time && abs(previous_condition) < callback.abstol
if integrator.event_last_time == counter && abs(previous_condition) < 100callback.abstol

# abs(previous_condition) < callback.abstol is for multiple events: only
# condition this on the correct event
Expand All @@ -55,14 +55,14 @@ end
tmp = integrator.tmp

if !(typeof(callback.idxs) <: Number)
integrator(tmp,integrator.tprev+100eps(typeof(integrator.tprev)))
integrator(tmp,integrator.tprev+100eps(integrator.tprev))
callback.idxs == nothing ? _tmp = tmp : _tmp = @view tmp[callback.idxs]
else
_tmp = integrator(integrator.tprev+100eps(typeof(integrator.tprev)))[callback.idxs]
_tmp = integrator(integrator.tprev+100eps(integrator.tprev))[callback.idxs]
end

tmp_condition = callback.condition(_tmp,integrator.tprev +
100eps(typeof(integrator.tprev)),
100eps(integrator.tprev),
integrator)

prev_sign = sign((tmp_condition-previous_condition)/integrator.tdir)
Expand Down Expand Up @@ -105,8 +105,8 @@ end
event_occurred,interp_index,Θs,prev_sign,prev_sign_index
end

function find_callback_time(integrator,callback)
event_occurred,interp_index,Θs,prev_sign,prev_sign_index = determine_event_occurance(integrator,callback)
function find_callback_time(integrator,callback,counter)
event_occurred,interp_index,Θs,prev_sign,prev_sign_index = determine_event_occurance(integrator,callback,counter)
dt = integrator.t - integrator.tprev
if event_occurred
if typeof(callback.condition) <: Nothing
Expand Down Expand Up @@ -135,9 +135,23 @@ function find_callback_time(integrator,callback)
if zero_func(top_Θ) == 0
Θ = top_Θ
else
Θ = prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),Roots.AlefeldPotraShi(),atol = callback.abstol/10))
if integrator.event_last_time == counter &&
abs(zero_func(bottom_θ)) < 100callback.abstol &&
prev_sign_index == 1
# Determined that there is an event by derivative
# But floating point error may make the end point negative
sign_top = sign(zero_func(top_Θ))
bottom_θ += 2eps(typeof(bottom_θ))
iter = 1
while sign(zero_func(bottom_θ)) == sign_top && iter < 12
bottom_θ *= 5
iter += 1
end
iter == 12 && error("Double callback crossing floating pointer reducer errored. Report this issue.")
end
Θ = prevfloat(find_zero(zero_func,(bottom_θ,top_Θ),Roots.AlefeldPotraShi(),atol = callback.abstol/100))
end

#Θ = prevfloat(...)
# prevfloat guerentees that the new time is either 1 floating point
# numbers just before the event or directly at zero, but not after.
Expand Down
6 changes: 3 additions & 3 deletions src/common_interface/integrator_types.jl
Expand Up @@ -36,7 +36,7 @@ mutable struct CVODEIntegrator{uType,pType,memType,solType,algType,fType,UFType,
uprev::tmpType
flag::Cint
just_hit_tstop::Bool
event_last_time::Bool
event_last_time::Int
end

function (integrator::CVODEIntegrator)(t::Number,deriv::Type{Val{T}}=Val{0}) where T
Expand Down Expand Up @@ -72,7 +72,7 @@ mutable struct ARKODEIntegrator{uType,pType,memType,solType,algType,fType,UFType
uprev::tmpType
flag::Cint
just_hit_tstop::Bool
event_last_time::Bool
event_last_time::Int
end

function (integrator::ARKODEIntegrator)(t::Number,deriv::Type{Val{T}}=Val{0}) where T
Expand Down Expand Up @@ -110,7 +110,7 @@ mutable struct IDAIntegrator{uType,duType,pType,memType,solType,algType,fType,UF
uprev::tmpType
flag::Cint
just_hit_tstop::Bool
event_last_time::Bool
event_last_time::Int
end

function (integrator::IDAIntegrator)(t::Number,deriv::Type{Val{T}}=Val{0}) where T
Expand Down
4 changes: 2 additions & 2 deletions src/common_interface/integrator_utils.jl
Expand Up @@ -10,10 +10,10 @@ function handle_callbacks!(integrator)
time,upcrossing,event_occured,idx,counter =
find_first_continuous_callback(integrator,continuous_callbacks...)
if event_occured
integrator.event_last_time = true
integrator.event_last_time = idx
continuous_modified,saved_in_cb = apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
else
integrator.event_last_time = false
integrator.event_last_time = 0
end
end
if !(typeof(discrete_callbacks)<:Tuple{})
Expand Down
6 changes: 3 additions & 3 deletions src/common_interface/solve.jl
Expand Up @@ -229,7 +229,7 @@ function DiffEqBase.__init(
timeseries_errors,dense_errors,save_end,
callbacks_internal,verbose,advance_to_tstop,stop_at_next_tstop)
CVODEIntegrator(u0,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts,
tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,false)
tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0)
end # function solve

function DiffEqBase.__init(
Expand Down Expand Up @@ -496,7 +496,7 @@ function DiffEqBase.__init(
timeseries_errors,dense_errors,save_end,
callbacks_internal,verbose,advance_to_tstop,stop_at_next_tstop)
ARKODEIntegrator(utmp,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts,
tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,false)
tout,tdir,sizeu,false,tmp,uprev,Cint(flag),false,0)
end # function solve

function tstop_saveat_disc_handling(tstops,saveat,tdir,tspan,tType)
Expand Down Expand Up @@ -769,7 +769,7 @@ function DiffEqBase.__init(
callbacks_internal,verbose,advance_to_tstop,stop_at_next_tstop)

IDAIntegrator(utmp,dutmp,prob.p,t0,t0,mem,_LS,_A,sol,alg,f!,userfun,jac,opts,
tout,tdir,sizeu,sizedu,false,tmp,uprev,Cint(flag),false,false)
tout,tdir,sizeu,sizedu,false,tmp,uprev,Cint(flag),false,0)
end # function solve

## Common calls
Expand Down

0 comments on commit 0fc57e9

Please sign in to comment.