
Efficiency of ExtendedJumpArray broadcasting in ode_interpolant #335

Closed
@meson800


Background

Continuing to optimize a system with VariableRateJumps and callbacks, I've found that ODE interpolation on ExtendedJumpArrays consumes ~80% of the total runtime of the solve. About 10% shows up in the various do_step backtraces; most of the runtime is in the find_first_continuous_callback call, and the vast majority of that is spent in Base.getindex(A::ExtendedJumpArray).

Problem

In particular, it looks like the combination of ExtendedJumpArrays, @muladd, and FastBroadcast produces suboptimal code. For my use case, the problematic function is:

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
    cache::Union{Tsit5ConstantCache, Tsit5Cache},
    idxs::Nothing, T::Type{Val{0}})
    @tsit5pre0
    @inbounds @.. broadcast=false out=y₀ +
                                      dt *
                                      (k[1] * b1Θ + k[2] * b2Θ + k[3] * b3Θ + k[4] * b4Θ +
                                       k[5] * b5Θ + k[6] * b6Θ + k[7] * b7Θ)
    out
end

but it seems to affect other interpolants as well.
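For reference, the weights that `@tsit5pre0` produces (written out by hand in the workaround further down) are polynomials in Θ, so for each entry the interpolant computes

$$
\mathrm{out} = y_0 + \mathrm{dt} \sum_{i=1}^{7} b_i(\Theta)\, k_i,
\qquad b_1(\Theta) = \Theta\,(r_{11} + r_{12}\Theta + r_{13}\Theta^2 + r_{14}\Theta^3),
\qquad b_i(\Theta) = \Theta^2\,(r_{i2} + r_{i3}\Theta + r_{i4}\Theta^2),\quad i = 2,\dots,7.
$$

Each entry of `out` depends only on the same entry of `y₀` and of each `k[i]`, so the computation can be split between the `.u` and `.jump_u` parts without changing the result.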

Examining this with Cthulhu and friends (looking at the LLVM IR, native code, etc.) shows an absolute morass of branching, with something like 4800 (!) branches in the assembly. The branches come from the inlined ExtendedJumpArray index check. As a result, the compiler doesn't seem to be able to unroll the loop, since it has to check "is index <= length(jump_array.u)" over and over and over.
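To make the branch concrete, here is a simplified, hypothetical stand-in for ExtendedJumpArray. `MiniJumpArray` and `update_elementwise!` are illustrative names, not the JumpProcesses definitions; the `getindex` is an assumption modeled on the index check described above. Every linear index goes through a `<=` comparison to pick the backing field, so an element-wise loop carries that branch in its body:

```julia
# Hypothetical, simplified stand-in for JumpProcesses' ExtendedJumpArray:
# the ODE state lives in `u`, the extra jump states in `jump_u`.
struct MiniJumpArray{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end

Base.size(A::MiniJumpArray) = (length(A.u) + length(A.jump_u),)
Base.IndexStyle(::Type{<:MiniJumpArray}) = IndexLinear()

# Every element read has to decide which field the index refers to.
Base.@propagate_inbounds function Base.getindex(A::MiniJumpArray, i::Int)
    return i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]
end

# Every element write makes the same decision.
Base.@propagate_inbounds function Base.setindex!(A::MiniJumpArray, v, i::Int)
    if i <= length(A.u)
        A.u[i] = v
    else
        A.jump_u[i - length(A.u)] = v
    end
    return A
end

# Element-wise out = y + dt * k: the field-selection branch is evaluated
# for each read and each write inside the loop body.
function update_elementwise!(out::MiniJumpArray, y::MiniJumpArray, dt, k::MiniJumpArray)
    @inbounds for i in eachindex(out)
        out[i] = y[i] + dt * k[i]
    end
    return out
end
```

With seven `k[i]` terms, as in the Tsit5 interpolant, each output element involves many of these checks.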

Solution

I suspect this would be a lot faster, and would stop churning the branch predictor with a bunch of mostly-useless <= comparisons to switch between the .u and .jump_u fields, if we could somehow turn this broadcast into two separate broadcasts, like:

out.u .= y₀.u .+ dt .* ...
out.jump_u .= y₀.jump_u .+ dt .* ...
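A minimal sketch of that split, assuming a hypothetical two-field container (`SplitState` and `update_split!` are illustrative names, not part of JumpProcesses or OrdinaryDiffEq). Because element i of the result depends only on element i of each input, two separate broadcasts give the same values as one pass over the concatenated array, but each broadcast now runs over a plain `Vector` with no per-element field check:

```julia
# Hypothetical two-field state mirroring the layout of ExtendedJumpArray.
struct SplitState{T}
    u::Vector{T}
    jump_u::Vector{T}
end

# out = y + dt * (k1 * b1 + k2 * b2), done as one fused broadcast per field.
# Each broadcast is over a plain Vector, so the loop body has no branch that
# chooses between `u` and `jump_u`.
function update_split!(out::SplitState, y::SplitState, dt,
                       k1::SplitState, k2::SplitState, b1, b2)
    @. out.u = y.u + dt * (k1.u * b1 + k2.u * b2)
    @. out.jump_u = y.jump_u + dt * (k1.jump_u * b1 + k2.jump_u * b2)
    return out
end
```

This has the same structure as the workaround method that follows, minus `@muladd` and `FastBroadcast.@..`.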

I worked around this by adding the following method to my user code. (I expanded one of the macros by hand because getting the imports right was annoying, but it's just the original _ode_interpolant! with the broadcast split into separate .u and .jump_u parts.)

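using MuladdMacro, FastBroadcast
using OrdinaryDiffEq
using JumpProcesses: ExtendedJumpArray
# Internal (non-exported) names that the @tsit5unpack expansion expects in scope
using OrdinaryDiffEq: constvalue, Tsit5Interp

# Tsit5 interpolant specialized for ExtendedJumpArray states: same math as the
# generic _ode_interpolant!, but broadcast separately over .u and .jump_u.
function OrdinaryDiffEq._ode_interpolant!(out::ExtendedJumpArray, Θ, dt, y₀::ExtendedJumpArray, y₁, k,
    cache::Union{OrdinaryDiffEq.Tsit5ConstantCache, OrdinaryDiffEq.Tsit5Cache},
    idxs::Nothing, T::Type{Val{0}})
    # Hand-expanded @tsit5pre0: load the r coefficients, then build the b_iΘ weights
    OrdinaryDiffEq.@tsit5unpack
    Θ² = Θ * Θ
    b1Θ = Θ * @evalpoly(Θ, r11, r12, r13, r14)
    b2Θ = Θ² * @evalpoly(Θ, r22, r23, r24)
    b3Θ = Θ² * @evalpoly(Θ, r32, r33, r34)
    b4Θ = Θ² * @evalpoly(Θ, r42, r43, r44)
    b5Θ = Θ² * @evalpoly(Θ, r52, r53, r54)
    b6Θ = Θ² * @evalpoly(Θ, r62, r63, r64)
    b7Θ = Θ² * @evalpoly(Θ, r72, r73, r74)
    # ODE part: a plain-array broadcast with no ExtendedJumpArray index branch
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.u=y₀.u +
                                      dt *
                                      (k[1].u * b1Θ + k[2].u * b2Θ + k[3].u * b3Θ + k[4].u * b4Θ +
                                       k[5].u * b5Θ + k[6].u * b6Θ + k[7].u * b7Θ)
    # Jump part: the same update applied to the extra jump states
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.jump_u=y₀.jump_u +
                                      dt *
                                      (k[1].jump_u * b1Θ + k[2].jump_u * b2Θ + k[3].jump_u * b3Θ + k[4].jump_u * b4Θ +
                                       k[5].jump_u * b5Θ + k[6].jump_u * b6Θ + k[7].jump_u * b7Θ)
    out
end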

This dramatically reduces the runtime, and the LLVM IR now shows "only" 123 branches, with good loop unrolling. There's still a boat-load of allocations happening in handle_callbacks, but I'll deal with that in a separate issue/PR.

Questions

  • Is it possible to write a dispatch that makes the @.. macro generate this kind of code automatically?
  • If not, one (annoying) approach would be to add ExtendedJumpArray methods to interpolants.jl by hand, but that is a lot more fragile. I'm also not sure where such methods would live, since putting them in JumpProcesses would add a dependency on OrdinaryDiffEq.
  • Is there a better approach?
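On the first question, one technique available in Base Julia is a custom `BroadcastStyle` whose `copyto!` rewrites a broadcast over the container into one broadcast per field. The sketch below applies it to a hypothetical `MiniJumpArray` stand-in (not the real ExtendedJumpArray; `MiniJumpStyle` and `unpack` are also illustrative names). Whether `FastBroadcast.@..` with `broadcast=false` routes through such a `copyto!` method is an unverified assumption, so this only demonstrates the idea for Base's `.=` and `@.`:

```julia
# Hypothetical stand-in for ExtendedJumpArray (not the JumpProcesses type).
struct MiniJumpArray{T} <: AbstractVector{T}
    u::Vector{T}
    jump_u::Vector{T}
end

Base.size(A::MiniJumpArray) = (length(A.u) + length(A.jump_u),)
Base.IndexStyle(::Type{<:MiniJumpArray}) = IndexLinear()

Base.@propagate_inbounds function Base.getindex(A::MiniJumpArray, i::Int)
    return i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]
end

# A dedicated broadcast style, so in-place broadcasts into a MiniJumpArray
# dispatch to the copyto! method below instead of the generic element loop.
struct MiniJumpStyle <: Broadcast.AbstractArrayStyle{1} end
MiniJumpStyle(::Val{1}) = MiniJumpStyle()
Base.BroadcastStyle(::Type{<:MiniJumpArray}) = MiniJumpStyle()

# Rebuild a (possibly nested) broadcast expression with every MiniJumpArray
# replaced by one of its fields; scalars and other arguments pass through.
unpack(bc::Broadcast.Broadcasted, field::Symbol) =
    Broadcast.broadcasted(bc.f, map(a -> unpack(a, field), bc.args)...)
unpack(A::MiniJumpArray, field::Symbol) = getfield(A, field)
unpack(x, ::Symbol) = x

# Split `dest .= expr` into one fused broadcast per field, each over a plain Vector.
function Base.copyto!(dest::MiniJumpArray, bc::Broadcast.Broadcasted{MiniJumpStyle})
    Broadcast.materialize!(dest.u, unpack(bc, :u))
    Broadcast.materialize!(dest.jump_u, unpack(bc, :jump_u))
    return dest
end
```

Usage: with the definitions above, `@. out = y + dt * (k * 2.0 + k * 1.0)` on three `MiniJumpArray`s runs as two branch-free broadcasts, one over `.u` and one over `.jump_u`. The design choice is to intercept the whole fused expression once in `copyto!` rather than per element in `getindex`.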
