Description
Background
Continuing to optimize a system with `VariableRateJump`s and callbacks, I've found that performing ODE interpolations on `ExtendedJumpArray`s is consuming around 80% of the total runtime of the solve. About 10% is actually in the various `do_step` backtraces, with most of the runtime in the `find_first_continuous_callback` call. Of that, the vast majority is happening in `Base.getindex(A::ExtendedJumpArray)`.
Problem
In particular, it looks like the combination of `ExtendedJumpArray`s, `@muladd`, and `FastBroadcast` is causing suboptimal code generation. For my use case, the problematic function is here:
```julia
@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
                                   cache::Union{Tsit5ConstantCache, Tsit5Cache},
                                   idxs::Nothing, T::Type{Val{0}})
    @tsit5pre0
    @inbounds @.. broadcast=false out=y₀ +
                                      dt *
                                      (k[1] * b1Θ + k[2] * b2Θ + k[3] * b3Θ + k[4] * b4Θ +
                                       k[5] * b5Θ + k[6] * b6Θ + k[7] * b7Θ)
    out
end
```
but it seems to affect others as well.
Examining this with Cthulhu and friends, looking at the LLVM and native code, etc. shows an absolute morass of branching, with something like 4800 (!) branches in the assembly. The branches come from inlining the `ExtendedJumpArray` index check. This churn means the compiler doesn't seem to be able to unroll the loop, since it has to repeatedly check "is `index < length(jump_array.u)`", over and over and over.
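For context, here is a minimal self-contained sketch of where that branch comes from. `ToyExtendedJumpArray` is a stand-in for illustration only; the real definition lives in JumpProcesses and differs in its details, but the routing check in `getindex` is the same shape:

```julia
# Toy stand-in (NOT the actual JumpProcesses definition) showing why every
# scalar index into an ExtendedJumpArray-like type carries a branch: the
# index must be routed to either the main state `u` or the appended `jump_u`.
struct ToyExtendedJumpArray{T} <: AbstractVector{T}
    u::Vector{T}       # main ODE state
    jump_u::Vector{T}  # appended jump variables
end

Base.size(A::ToyExtendedJumpArray) = (length(A.u) + length(A.jump_u),)

# This `i <= length(A.u)` check is what gets inlined into every element
# access of the fused broadcast, defeating loop unrolling/vectorization.
Base.@propagate_inbounds function Base.getindex(A::ToyExtendedJumpArray, i::Int)
    i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]
end

A = ToyExtendedJumpArray([1.0, 2.0, 3.0], [10.0])
A[2], A[4]  # routes to u[2] and jump_u[1] respectively
```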
Solution
I'm guessing that this would be a lot faster, and would stop churning the branch predictor with a bunch of mostly-useless `<=` calls switching between the `.u` and `.jump_u` members, if we could somehow turn these calls into separate broadcasts like

```julia
out.u = y0.u + dt * ...
out.jump_u = y0.jump_u + dt * ...
```
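As a toy illustration of the idea (using an illustrative stand-in type, not the actual `ExtendedJumpArray`, and made-up coefficient values), splitting the fused broadcast into two plain `Vector` broadcasts removes the per-element branch entirely:

```julia
# Toy stand-in for ExtendedJumpArray; not the actual JumpProcesses type.
struct ToyEJA
    u::Vector{Float64}       # main ODE state
    jump_u::Vector{Float64}  # appended jump variables
end

y0  = ToyEJA([1.0, 2.0], [0.5])
k1  = ToyEJA([0.25, 0.5], [0.25])
out = ToyEJA(zeros(2), zeros(1))
dt, b1 = 0.5, 1.0  # made-up stand-ins for the real step size / tableau weight

# Two separate broadcasts over plain Vectors: no u-vs-jump_u index check
# inside the loop, so the compiler is free to unroll and vectorize.
@. out.u = y0.u + dt * (k1.u * b1)
@. out.jump_u = y0.jump_u + dt * (k1.jump_u * b1)
```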
I fixed this by adding the following dispatch to my user code (I unpacked one of the macros because getting the imports right was annoying, but it's just the original `_ode_interpolant!` with the `.u` and `.jump_u` parts unrolled):
```julia
using MuladdMacro, FastBroadcast
using OrdinaryDiffEq: constvalue, Tsit5Interp

function OrdinaryDiffEq._ode_interpolant!(out::ExtendedJumpArray, Θ, dt, y₀::ExtendedJumpArray, y₁, k,
                                          cache::Union{OrdinaryDiffEq.Tsit5ConstantCache, OrdinaryDiffEq.Tsit5Cache},
                                          idxs::Nothing, T::Type{Val{0}})
    OrdinaryDiffEq.@tsit5unpack
    Θ² = Θ * Θ
    b1Θ = Θ * @evalpoly(Θ, r11, r12, r13, r14)
    b2Θ = Θ² * @evalpoly(Θ, r22, r23, r24)
    b3Θ = Θ² * @evalpoly(Θ, r32, r33, r34)
    b4Θ = Θ² * @evalpoly(Θ, r42, r43, r44)
    b5Θ = Θ² * @evalpoly(Θ, r52, r53, r54)
    b6Θ = Θ² * @evalpoly(Θ, r62, r63, r64)
    b7Θ = Θ² * @evalpoly(Θ, r72, r73, r74)
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.u=y₀.u +
                                                              dt *
                                                              (k[1].u * b1Θ + k[2].u * b2Θ + k[3].u * b3Θ + k[4].u * b4Θ +
                                                               k[5].u * b5Θ + k[6].u * b6Θ + k[7].u * b7Θ)
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.jump_u=y₀.jump_u +
                                                                   dt *
                                                                   (k[1].jump_u * b1Θ + k[2].jump_u * b2Θ + k[3].jump_u * b3Θ + k[4].jump_u * b4Θ +
                                                                    k[5].jump_u * b5Θ + k[6].jump_u * b6Θ + k[7].jump_u * b7Θ)
    out
end
```
This dramatically reduces the runtime, and inspecting the LLVM code shows "only" 123 branches with good loop unrolling. There's still a boatload of allocations happening in `handle_callbacks`, but I'll deal with that in a separate issue/PR.
Questions
- Is it possible to write a dispatch for the `@..` macro to do this type of code generation automatically?
- If not, an annoying approach would be to manually add `ExtendedJumpArray` dispatches to interpolants.jl, but that's a lot more fragile. I'm also not sure where this dispatch would go, as adding it into JumpProcesses would come with a dependency on `OrdinaryDiffEq`.
- Is there a better approach?