
Efficiency of ExtendedJumpArray broadcasting in ode_interpolant #335

Closed
meson800 opened this issue Jul 20, 2023 · 10 comments · Fixed by #340
Comments

@meson800
Contributor

meson800 commented Jul 20, 2023

Background

Continuing to optimize a system with VariableRateJumps and callbacks, I've found that performing ODE interpolations on ExtendedJumpArrays is consuming around 80% of the total runtime of the solve. About 10% is actually in the various do_step backtraces, with most of the runtime in the find_first_continuous_callback call. Of that, the vast majority is happening in Base.getindex(A::ExtendedJumpArray)

Problem

In particular, it looks like the combination of ExtendedJumpArrays, @muladd, and FastBroadcast is causing suboptimal code generation. For my use case, the problematic function is here:

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
    cache::Union{Tsit5ConstantCache, Tsit5Cache},
    idxs::Nothing, T::Type{Val{0}})
    @tsit5pre0
    @inbounds @.. broadcast=false out=y₀ +
                                      dt *
                                      (k[1] * b1Θ + k[2] * b2Θ + k[3] * b3Θ + k[4] * b4Θ +
                                       k[5] * b5Θ + k[6] * b6Θ + k[7] * b7Θ)
    out
end

but it seems to affect others as well.

Examining this with Cthulhu and friends, looking at the LLVM and native code, etc., shows an absolute morass of branching, with something like 4800 (!) branches in the assembly. The branches come from inlining of the ExtendedJumpArray index check. This churn means the compiler doesn't seem to be able to unroll the loop, since it has to repeatedly check "is index <= length(jump_array.u)", over and over and over.
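For reference, the per-element check being inlined everywhere looks roughly like this (a sketch of the pattern described above, not the exact JumpProcesses source):

```julia
# Sketch of the branchy indexing pattern: every linear index has to decide
# which backing array it falls into, which blocks unrolling and SIMD.
Base.@propagate_inbounds function Base.getindex(A::ExtendedJumpArray, i::Int)
    i <= length(A.u) ? A.u[i] : A.jump_u[i - length(A.u)]
end
```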

Solution

I'm guessing this would be a lot faster, and would stop churning the branch predictor with a bunch of mostly-useless (<=) checks to switch between the .u and .jump_u members, if we could somehow turn these calls into separate calls like

out.u = y0.u + dt * ...
out.jump_u = y0.jump_u + dt * ...

I fixed this by adding this dispatch to my user code (I unpacked one of the macros because getting the imports right was annoying, but it's just the original ode_interpolant with the .u and .jump_u parts unrolled):

using MuladdMacro, FastBroadcast
using OrdinaryDiffEq: constvalue, Tsit5Interp
function OrdinaryDiffEq._ode_interpolant!(out::ExtendedJumpArray, Θ, dt, y₀::ExtendedJumpArray, y₁, k,
    cache::Union{OrdinaryDiffEq.Tsit5ConstantCache, OrdinaryDiffEq.Tsit5Cache},
    idxs::Nothing, T::Type{Val{0}})
    OrdinaryDiffEq.@tsit5unpack
    Θ² = Θ * Θ
    b1Θ = Θ * @evalpoly(Θ, r11, r12, r13, r14)
    b2Θ = Θ² * @evalpoly(Θ, r22, r23, r24)
    b3Θ = Θ² * @evalpoly(Θ, r32, r33, r34)
    b4Θ = Θ² * @evalpoly(Θ, r42, r43, r44)
    b5Θ = Θ² * @evalpoly(Θ, r52, r53, r54)
    b6Θ = Θ² * @evalpoly(Θ, r62, r63, r64)
    b7Θ = Θ² * @evalpoly(Θ, r72, r73, r74)
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.u=y₀.u +
                                      dt *
                                      (k[1].u * b1Θ + k[2].u * b2Θ + k[3].u * b3Θ + k[4].u * b4Θ +
                                       k[5].u * b5Θ + k[6].u * b6Θ + k[7].u * b7Θ)
    @muladd @inbounds FastBroadcast.@.. broadcast=false out.jump_u=y₀.jump_u +
                                      dt *
                                      (k[1].jump_u * b1Θ + k[2].jump_u * b2Θ + k[3].jump_u * b3Θ + k[4].jump_u * b4Θ +
                                       k[5].jump_u * b5Θ + k[6].jump_u * b6Θ + k[7].jump_u * b7Θ)
    out
end

This dramatically reduces the runtime, and inspecting the LLVM code shows "only" 123 branches with good loop unrolling. There's still a boat-load of allocations happening in handle_callbacks, but I'll deal with that with a separate issue/PR.

Questions

  • Is it possible to write a dispatch for the @.. macro to do this type of code generation automatically?
  • If not, an annoying approach would be to manually add ExtendedJumpArray dispatches to interpolants.jl, but that's a lot more fragile. I'm also not sure where this dispatch would go, as adding this dispatch into JumpProcesses would also come with a dependency on OrdinaryDiffEq.
  • Is there a better approach?
@meson800 meson800 changed the title Efficiency of ExtendedJumpArray broadcasting Efficiency of ExtendedJumpArray broadcasting in ode_interpolant Jul 20, 2023
@isaacsas
Member

I don't think we'd want to add ODE-method-dependent dispatches. That would be a nightmare to put together and maintain. Dispatches that are not ODE-method dependent would be fine (i.e. things that are used by all / many methods like your last PR, so more dispatches on things in DiffEqBase, SciMLBase, or Julia's Base). But maybe there is an ExtendedJumpArray broadcast feature that just isn't implemented and, if dispatched, would fix the issue here? (I really have no familiarity with the broadcast API unfortunately, so can't help with suggestions. Maybe @ChrisRackauckas has some suggestions.)

@.. comes from FastBroadcast.jl if you want to look at what it does:

https://github.com/YingboMa/FastBroadcast.jl

@ChrisRackauckas
Member

https://github.com/YingboMa/FastBroadcast.jl/blob/master/src/FastBroadcast.jl#L19

Try overriding this to be false so it doesn't use the linear indexing?

I don't think we'd want to add ODE method dependent dispatches. That would be a nightmare to put together and maintain.

Yes, it should just be one thing about broadcast lowering

@meson800
Contributor Author

Try overriding this to be false so it doesn't use the linear indexing?

Overriding that to false doesn't seem to solve the problem; there's still a bunch of calls to getindex, and specifically to x <= length(eja.u).

I've been thinking through the broadcast interface; it feels like what we want is a broadcast style that flattens the broadcast kernel, then applies it to both u and jump_u. It does seem like, at least in 2018, someone was trying to do something similar using the broadcast rules: JuliaLang/julia#27988 (comment)

I'll update if I make any progress. Right now my debugging frustratingly seems to show that on broadcast operations it's not using the defined ExtendedJumpArrayStyle, e.g. this code:

using JumpProcesses
k1 = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(5), rand(2))
k2 = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(5), rand(2))
@code_warntype k1 .+ k2

shows:

Arguments
  #self#::Core.Const(var"##dotfunction#1496#37"())
  x1::ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
  x2::ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}
Body::Vector{Float64}
1 ─ %1 = Base.broadcasted(Main.:+, x1, x2)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(+), Tuple{ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}, ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}}}
│   %2 = Base.materialize(%1)::Vector{Float64}
└──      return %2

which is actually just wrong; this should not just return a Vector{Float64}. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only [ccbc3e58] JumpProcesses v9.7.2 when running Pkg.status)

Running k1 + k2 results in an ExtendedJumpArray.
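For context, wiring a custom style into Julia's broadcasting machinery needs roughly this shape (a minimal sketch of the interface hooks, assuming the style type is defined as shown; it is not the actual JumpProcesses code):

```julia
# Sketch: minimal hooks for broadcasting to pick up a custom array style.
struct ExtendedJumpArrayStyle <: Broadcast.AbstractArrayStyle{1} end

# This method is what would route `k1 .+ k2` away from DefaultArrayStyle:
Base.BroadcastStyle(::Type{<:ExtendedJumpArray}) = ExtendedJumpArrayStyle()

# Required by the AbstractArrayStyle contract so mixing with scalars works:
ExtendedJumpArrayStyle(::Val{1}) = ExtendedJumpArrayStyle()

# Without a `similar` overload for Broadcasted{ExtendedJumpArrayStyle},
# out-of-place broadcasts still can't allocate an ExtendedJumpArray output.
```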

@ChrisRackauckas
Member

which is actually just wrong; this should not just return a Vector{Float64}. I don't know why it's doing this, but I checked that this happens even with a clean Julia environment (only [ccbc3e58] JumpProcesses v9.7.2 when running Pkg.status)

This means it's hitting the fallback broadcast style, and when it's using the AbstractArray style then it defaults to returning an Array. And that would use indexing and be type unstable if u and jump_u are not the same.

The thing to look at would probably be https://github.com/jonniedie/ComponentArrays.jl. It solves these kinds of problems quite nicely, so we may just want to lift some of its broadcast implementation. I've also considered completely removing the ExtendedJumpArray and just using a ComponentArray.

Though the issue is that it's somewhat magical to the user if they solve using an Array of size 5 and then get a solution that's an array of size 7. The reason this machinery is a bit different is to try to mask from the user the fact that it is constructing and solving a larger system. Potentially, the solution is to just use a ComponentArray but make the saving code allow for saving just the "Array" part.

@meson800
Contributor Author

Thanks for the reference! It looks like ComponentArrays actually has very little broadcasting magic, but a much more advanced indexing magic. I'll see if there is a straightforward way to get the ExtendedJumpArray to actually use the broadcasting code that has already been written, otherwise I'll see if there is some indexing magic to use.

@ChrisRackauckas
Member

I think the fact that ComponentArrays forces everything to have the same type, thus you cannot end up in the situation where u is Int and jump_u is a Float64, is doing a lot of heavy lifting there. I am unsure if we should just promote the type.

@meson800
Contributor Author

Yeah, I did test out getting broadcasting working and ran into something similar. The current broadcast code wasn't actually doing anything, but once I added two more interface functions in Broadcast, broadcasting "works": dot operations properly fuse, do what you expect, and return an ExtendedJumpArray. The limitations were:

  • Julia's broadcasting code expects that output arrays are of a single type, and the ExtendedJumpArray code uses the type of the u array, so even if you have a (u,jump_u) = (Float64, Int64) setup, a single dot broadcast converts it to (Float64, Float64). It might be possible to propagate (Float64, Int64) by encoding it into the broadcast style type, but I think this would also require manually writing style mixing rules to implement the expected integer/floating-point type promotion rules, which would be gross (e.g. if we multiply a (Float64, Int64) by pi, we really do expect to get a (Float64, Float64) out).
  • In the ode_interpolant section, I no longer see a lot of time spent in the inefficient getindex; I now see a lot of time in the default materialize! implementation, which FastBroadcast falls back to. In particular, the @.. macro can't "see" the materialize! calls inside the broadcast machinery, so it can't convert them to fast_materialize! calls.

I might try out the hacky thing first of just directly calling fast_materialize! inside the ExtendedJumpArray broadcast code, just to see if there are speed improvements. This would definitely not be a PR-worthy solution, but would at least let me know if it could be improved.

@meson800
Contributor Author

meson800 commented Aug 28, 2023

I haven't figured out a solution to this so far, though I'd like to; even simply adding two ExtendedJumpArrays together is about five times slower than adding two equivalent Vector{Float64}s, which feels bad.

Notes so far:

  1. The current broadcasting overloads that define the ExtendedJumpArrayStyle don't do anything, at least in the current Julia. This is because it defines a method for Broadcast.BroadcastStyle but not the Base.BroadcastStyle function, but even if you include this, broadcasting falls back to DefaultArrayStyle. Since it appears unused, I'm considering just replacing it.
  2. To handle dot-broadcasting, the current implementation of copyto! does the efficient thing we want: it unpacks the broadcast call and handles the u and jump_u arrays separately. It just doesn't currently get called, because ExtendedJumpArrayStyle doesn't fully implement the broadcasting interface.
  3. To address the problem of FastBroadcast as used in ode_interpolant, the @.. macro is replacing calls to materialize! with fast_materialize!. Unfortunately, that means that any special code we put into copyto! won't get called by FastBroadcast. Technically, we could define an overload for FastBroadcast.fast_materialize and FastBroadcast.fast_materialize! that are effectively identical to copyto! and just internally unpack the Broadcasted object.
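A sketch of what the unpacking in note 2 amounts to (the `repack_args` helper is hypothetical; it would map every ExtendedJumpArray argument in the Broadcasted tree to its `.u` or `.jump_u` field):

```julia
function Base.copyto!(dest::ExtendedJumpArray,
                      bc::Broadcast.Broadcasted{ExtendedJumpArrayStyle})
    # Run the fused kernel over each plain backing Vector separately,
    # so the inner loops never hit the two-array index check.
    copyto!(dest.u,
            Broadcast.Broadcasted(bc.f, repack_args(bc.args, Val(:u))))
    copyto!(dest.jump_u,
            Broadcast.Broadcasted(bc.f, repack_args(bc.args, Val(:jump_u))))
    dest
end
```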

At the very least, I'm trying to address 1) and 2). There are correctness issues with the current fallback to DefaultArrayStyle, like this happily working despite the u/jump_u mismatch:

ExtendedJumpArray(rand(100), rand(10)) .+ ExtendedJumpArray(rand(90), rand(20))

For 3), @ChrisRackauckas, do you think it's ok to add a dependency on FastBroadcast to this package? Or a Julia 1.9 extension if FastBroadcast is loaded? Because FastBroadcast.@.. is first lowering broadcasted code and then just replacing materialize! calls, there's no hook in the Broadcast interface we can use to override its behavior besides defining methods for fast_materialize and fast_materialize!
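For illustration, the overload in note 3 would look something like this (the exact FastBroadcast signature is an assumption here, based only on the internals discussed above):

```julia
# Sketch: route FastBroadcast's entry point back through ordinary copyto!,
# so @..-generated code reuses whatever broadcast overloads we define.
function FastBroadcast.fast_materialize!(::SB, ::DB, dest::ExtendedJumpArray,
                                         bc::Broadcast.Broadcasted) where {SB, DB}
    copyto!(dest, bc)
end
```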

@meson800
Contributor Author

I did it, at least for normal broadcasting! I should have a PR in today. I rewrote the broadcast rules and slightly changed how broadcast repacking works.

I'm working on the FastBroadcast overloads right now, confirming that they work as expected.

For the benchmark, I compare just against a linear vector. The old fallback mechanism used to be ~3-5x as slow, but now it is effectively the same. Checking with Cthulhu shows that efficient simd instructions are being emitted now.

using BenchmarkTools, JumpProcesses, Random
rng = Random.default_rng()  # rng was implicit in the original snippet
bench_out_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 500000),
                                                                                  rand(rng, 500000))
bench_in_array = ExtendedJumpArray{Float64, 1, Vector{Float64}, Vector{Float64}}(rand(rng, 500000),
                                                                                 rand(rng, 500000))
base_out_array = rand(rng, 500000 * 2)
base_in_array = rand(rng, 500000 * 2)

function test_single_dot(out, array)
     @inbounds  @. out = array + 1.0 * array + 1.2 * array
end
test_single_dot(bench_out_array, bench_in_array)
@benchmark test_single_dot(bench_out_array, bench_in_array)
@benchmark test_single_dot(base_out_array, base_in_array)

[benchmark results screenshot]

@isaacsas
Member

I’d be fine with adding a dependency on FastBroadcast if that enables you to fix the issue. Please go ahead and add it if needed.
