Skip to content

Commit

Permalink
Merge pull request #2021 from meson800/apply_callback_type_stability
Browse files Browse the repository at this point in the history
Apply callbacks with a type-stable generated function.
  • Loading branch information
ChrisRackauckas committed Aug 24, 2023
2 parents da20096 + 264edd9 commit bbc1f6a
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ ADTypes = "0.1, 0.2"
Adapt = "1.1, 2.0, 3.0"
ArrayInterface = "6, 7"
DataStructures = "0.18"
DiffEqBase = "6.125.0"
DiffEqBase = "6.128.2"
DocStringExtensions = "0.8, 0.9"
ExponentialUtilities = "1.22"
FastBroadcast = "0.1.9, 0.2"
Expand Down
39 changes: 33 additions & 6 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,31 @@ function _loopfooter!(integrator)
nothing
end

# Use a generated function to call apply_callback! in a type-stable way
@generated function apply_ith_callback!(integrator,
time, upcrossing, event_idx, cb_idx,
callbacks::NTuple{N,
Union{ContinuousCallback,
VectorContinuousCallback}}) where {N}
ex = quote
throw(BoundsError(callbacks, cb_idx))
end
for i in 1:N
# N.B: doing this as an explicit if (return) else (rest of expression)
# means that LLVM compiles this into a switch.
# This seemingly isn't the case with just if (return) end (rest of expression)
ex = quote
if (cb_idx == $i)
return DiffEqBase.apply_callback!(integrator, callbacks[$i], time,
upcrossing, event_idx)
else
$ex
end
end
end
ex
end

function handle_callbacks!(integrator)
discrete_callbacks = integrator.opts.callback.discrete_callbacks
continuous_callbacks = integrator.opts.callback.continuous_callbacks
Expand All @@ -295,22 +320,23 @@ function handle_callbacks!(integrator)
saved_in_cb = false
if !(typeof(continuous_callbacks) <: Tuple{})
time, upcrossing, event_occurred, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback(integrator,
continuous_callbacks...)
continuous_callbacks...)
if event_occurred
integrator.event_last_time = idx
integrator.vector_event_last_time = event_idx
continuous_modified, saved_in_cb = DiffEqBase.apply_callback!(integrator,
continuous_callbacks[idx],
time, upcrossing,
event_idx)
continuous_modified, saved_in_cb = apply_ith_callback!(integrator,
time, upcrossing,
event_idx,
idx,
continuous_callbacks)
else
integrator.event_last_time = 0
integrator.vector_event_last_time = 1
end
end
if !integrator.force_stepfail && !(typeof(discrete_callbacks) <: Tuple{})
discrete_modified, saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,
discrete_callbacks...)
discrete_callbacks...)
end
if !saved_in_cb
savevalues!(integrator)
Expand All @@ -321,6 +347,7 @@ function handle_callbacks!(integrator)
integrator.do_error_check = false
handle_callback_modifiers!(integrator)
end
nothing
end

function handle_callback_modifiers!(integrator::ODEIntegrator)
Expand Down
45 changes: 45 additions & 0 deletions test/integrators/callback_allocation_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using OrdinaryDiffEq, Test

# Setup a simple ODE problem with several callbacks (to test LLVM code gen)
# We will manually trigger the first callback and check its allocations.
function f!(du, u, p, t)
du .= .-u
end

cond_1(u, t, integrator) = u[1] - 0.5
cond_2(u, t, integrator) = u[2] + 0.5
cond_3(u, t, integrator) = u[2] + 0.6
cond_4(u, t, integrator) = u[2] + 0.7
cond_5(u, t, integrator) = u[2] + 0.8
cond_6(u, t, integrator) = u[2] + 0.9
cond_7(u, t, integrator) = u[2] + 0.1
cond_8(u, t, integrator) = u[2] + 0.11
cond_9(u, t, integrator) = u[2] + 0.12

function cb_affect!(integrator)
integrator.p[1] += 1
end

cbs = CallbackSet(ContinuousCallback(cond_1, cb_affect!),
ContinuousCallback(cond_2, cb_affect!),
ContinuousCallback(cond_3, cb_affect!),
ContinuousCallback(cond_4, cb_affect!),
ContinuousCallback(cond_5, cb_affect!),
ContinuousCallback(cond_6, cb_affect!),
ContinuousCallback(cond_7, cb_affect!),
ContinuousCallback(cond_8, cb_affect!),
ContinuousCallback(cond_9, cb_affect!))

integrator = init(ODEProblem(f!, [0.8, 1.0], (0.0, 100.0), [0, 0]), Tsit5(), callback = cbs,
save_on = false);
# Force a callback event to occur so we can call handle_callbacks! directly.
# Step to a point where u[1] is still > 0.5, so we can force it below 0.5 and
# call handle callbacks
step!(integrator, 0.1, true)

function handle_allocs(integrator)
integrator.u[1] = 0.4
@allocations OrdinaryDiffEq.handle_callbacks!(integrator)
end
handle_allocs(integrator);
@test handle_allocs(integrator) == 0
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ end
@time @safetestset "Events Tests" include("integrators/ode_event_tests.jl")
@time @safetestset "Alg Events Tests" include("integrators/alg_events_tests.jl")
@time @safetestset "Discrete Callback Dual Tests" include("integrators/discrete_callback_dual_test.jl")
@time @safetestset "Callback Allocation Tests" include("integrators/callback_allocation_tests.jl")
@time @safetestset "Iterator Tests" include("integrators/iterator_tests.jl")
@time @safetestset "Integrator Interface Tests" include("integrators/integrator_interface_tests.jl")
@time @safetestset "Error Check Tests" include("integrators/check_error.jl")
Expand Down

0 comments on commit bbc1f6a

Please sign in to comment.