From 9c350f514d8699449b95eeab0221eb5e64ddf689 Mon Sep 17 00:00:00 2001 From: Christopher Johnstone Date: Tue, 22 Aug 2023 16:43:18 -0400 Subject: [PATCH 1/3] Apply callbacks with a type-stable generated function. Currently, as referenced in SciML/DifferentialEquations.jl#971, the old implementation of `handle_callbacks!` directly calls. `apply_callback!` on `continuous_callbacks[idx]`, which is inherently type-unstable because `apply_callback!` is specialized on the callback type. This commit adds a generated function `apply_ith_callback!` which generates type-stable code to do the same thing, where for each callback tuple type, the generated function unrolls the tuple by checking the callback index against static indicies. As a nice bonus, this generated function seems to often be converted into a switch statement at the LLVM level: ``` switch i64 %4, label %L46 [ i64 9, label %L3 i64 8, label %L8 i64 7, label %L13 i64 6, label %L18 i64 5, label %L23 i64 4, label %L28 i64 3, label %L33 i64 2, label %L38 i64 1, label %L43 ] ``` For testing, I added an allocation test which sets up a simple ODE problem, steps the integrator manually to before the first callback, then manipulates integrator state past the first callback point. This way, we can directly call `handle_callbacks!` and write a test on the allocation count. I confirm that (at least testing against commit SciML/DiffEqBase.jl@1799fc3, the current master branch tip in DiffEqBase.jl), the new method does not allocate, whereas the old one allocates. This may not be the case until a new release is cut of DiffEqBase.jl, because the old version of `find_first_continuous_callback` might allocate. --- src/integrators/integrator_utils.jl | 39 +++++++++++++--- test/integrators/callback_allocation_tests.jl | 45 +++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 79 insertions(+), 6 deletions(-) create mode 100644 test/integrators/callback_allocation_tests.jl diff --git a/src/integrators/integrator_utils.jl b/src/integrators/integrator_utils.jl index 6f377358fc..cdc80406fd 100644 --- a/src/integrators/integrator_utils.jl +++ b/src/integrators/integrator_utils.jl @@ -285,6 +285,31 @@ function _loopfooter!(integrator) nothing end +# Use a generated function to call apply_callback! in a type-stable way +@generated function apply_ith_callback!(integrator, + time, upcrossing, event_idx, cb_idx, + callbacks::NTuple{N, + Union{ContinuousCallback, + VectorContinuousCallback}}) where {N} + ex = quote + throw(BoundsError(callbacks, cb_idx)) + end + for i in 1:N + # N.B: doing this as an explicit if (return) else (rest of expression) + # means that LLVM compiles this into a switch. + # This seemingly isn't the case with just if (return) end (rest of expression) + ex = quote + if (cb_idx == $i) + return DiffEqBase.apply_callback!(integrator, callbacks[$i], time, + upcrossing, event_idx) + else + $ex + end + end + end + ex +end + function handle_callbacks!(integrator) discrete_callbacks = integrator.opts.callback.discrete_callbacks continuous_callbacks = integrator.opts.callback.continuous_callbacks @@ -295,14 +320,15 @@ function handle_callbacks!(integrator) saved_in_cb = false if !(typeof(continuous_callbacks) <: Tuple{}) time, upcrossing, event_occurred, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback(integrator, - continuous_callbacks...) + continuous_callbacks...) if event_occurred integrator.event_last_time = idx integrator.vector_event_last_time = event_idx - continuous_modified, saved_in_cb = DiffEqBase.apply_callback!(integrator, - continuous_callbacks[idx], - time, upcrossing, - event_idx) + continuous_modified, saved_in_cb = apply_ith_callback!(integrator, + time, upcrossing, + event_idx, + idx, + continuous_callbacks) else integrator.event_last_time = 0 integrator.vector_event_last_time = 1 @@ -310,7 +336,7 @@ function handle_callbacks!(integrator) end if !integrator.force_stepfail && !(typeof(discrete_callbacks) <: Tuple{}) discrete_modified, saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator, - discrete_callbacks...) + discrete_callbacks...) end if !saved_in_cb savevalues!(integrator) @@ -321,6 +347,7 @@ function handle_callbacks!(integrator) integrator.do_error_check = false handle_callback_modifiers!(integrator) end + nothing end function handle_callback_modifiers!(integrator::ODEIntegrator) diff --git a/test/integrators/callback_allocation_tests.jl b/test/integrators/callback_allocation_tests.jl new file mode 100644 index 0000000000..33949752a2 --- /dev/null +++ b/test/integrators/callback_allocation_tests.jl @@ -0,0 +1,45 @@ +using OrdinaryDiffEq, Test + +# Setup a simple ODE problem with several callbacks (to test LLVM code gen) +# We will manually trigger the first callback and check its allocations. +function f!(du, u, p, t) + du .= .-u +end + +cond_1(u, t, integrator) = u[1] - 0.5 +cond_2(u, t, integrator) = u[2] + 0.5 +cond_3(u, t, integrator) = u[2] + 0.6 +cond_4(u, t, integrator) = u[2] + 0.7 +cond_5(u, t, integrator) = u[2] + 0.8 +cond_6(u, t, integrator) = u[2] + 0.9 +cond_7(u, t, integrator) = u[2] + 0.1 +cond_8(u, t, integrator) = u[2] + 0.11 +cond_9(u, t, integrator) = u[2] + 0.12 + +function cb_affect!(integrator) + integrator.p[1] += 1 +end + +cbs = CallbackSet(ContinuousCallback(cond_1, cb_affect!), + ContinuousCallback(cond_2, cb_affect!), + ContinuousCallback(cond_3, cb_affect!), + ContinuousCallback(cond_4, cb_affect!), + ContinuousCallback(cond_5, cb_affect!), + ContinuousCallback(cond_6, cb_affect!), + ContinuousCallback(cond_7, cb_affect!), + ContinuousCallback(cond_8, cb_affect!), + ContinuousCallback(cond_9, cb_affect!)) + +integrator = init(ODEProblem(f!, [0.8, 1.0], (0.0, 100.0), [0, 0]), Tsit5(), callback = cbs, + save_on = false); +# Force a callback event to occur so we can call handle_callbacks! directly. +# Step to a point where u[1] is still > 0.5, so we can force it below 0.5 and +# call handle callbacks +step!(integrator, 0.1, true) + +function handle_allocs(integrator) + integrator.u[1] = 0.4 + @allocations OrdinaryDiffEq.handle_callbacks!(integrator) +end +handle_allocs(integrator); +@test handle_allocs(integrator) == 0 diff --git a/test/runtests.jl b/test/runtests.jl index e1e2e85b30..aef475dbe2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,6 +81,7 @@ end @time @safetestset "Events Tests" include("integrators/ode_event_tests.jl") @time @safetestset "Alg Events Tests" include("integrators/alg_events_tests.jl") @time @safetestset "Discrete Callback Dual Tests" include("integrators/discrete_callback_dual_test.jl") + @time @safetestset "Callback Allocation Tests" include("integrators/callback_allocation_tests.jl") @time @safetestset "Iterator Tests" include("integrators/iterator_tests.jl") @time @safetestset "Integrator Interface Tests" include("integrators/integrator_interface_tests.jl") @time @safetestset "Error Check Tests" include("integrators/check_error.jl") From 918c8d92be102fa672d0237178ea3d4345ff13e8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 24 Aug 2023 09:29:58 +0200 Subject: [PATCH 2/3] Update ode_cache_tests.jl --- test/integrators/ode_cache_tests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/integrators/ode_cache_tests.jl b/test/integrators/ode_cache_tests.jl index 249ebaf8f0..30069d4f94 100644 --- a/test/integrators/ode_cache_tests.jl +++ b/test/integrators/ode_cache_tests.jl @@ -37,8 +37,8 @@ end affect! = function (integrator) u = integrator.u - resize!(integrator, length(u) + 1) maxidx = findmax(u)[2] + resize!(integrator, length(u) + 1) Θ = rand() / 5 + 0.25 u[maxidx] = Θ u[end] = 1 - Θ From 264edd9144c82363550d31cdf50ff394cd94dada Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 24 Aug 2023 09:30:42 +0200 Subject: [PATCH 3/3] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d6f429afd7..0d21a9c40e 100644 --- a/Project.toml +++ b/Project.toml @@ -49,7 +49,7 @@ ADTypes = "0.1, 0.2" Adapt = "1.1, 2.0, 3.0" ArrayInterface = "6, 7" DataStructures = "0.18" -DiffEqBase = "6.125.0" +DiffEqBase = "6.128.2" DocStringExtensions = "0.8, 0.9" ExponentialUtilities = "1.22" FastBroadcast = "0.1.9, 0.2"