Fixing type-instability in callback handling #971

Closed
meson800 opened this issue Jul 22, 2023 · 4 comments
Comments

@meson800
meson800 commented Jul 22, 2023

In the handle_callbacks! function in OrdinaryDiffEq, both the find_first_continuous_callback and apply_callback! lines are either type-unstable or can sometimes fail to be inferred properly. Both can be solved with generated functions, but I'd like feedback on how to PR this (PR into both OrdinaryDiffEq and DiffEqBase? Only add functions to OrdinaryDiffEq?) and if the generated functions have any disadvantages.

Fixing apply_callback! type instability

The line

            continuous_modified, saved_in_cb = DiffEqBase.apply_callback!(integrator,
                continuous_callbacks[idx],
                time, upcrossing,
                event_idx)

allocates because indexing continuous_callbacks[idx] with a runtime index is type-unstable. @code_warntype shows the instability:

│    %48 = DiffEqBase.apply_callback!::Core.Const(DiffEqBase.apply_callback!)
│    %49 = Base.getindex(continuous_callbacks, idx)::ContinuousCallback{F1, F2, F3, typeof(SciMLBase.INITIALIZE_DEFAULT), typeof(SciMLBase.FINALIZE_DEFAULT), Float64, Int64, T3, Nothing, Int64} where {F1, F2, F3, T3}
│    %50 = time::Float64
│    %51 = upcrossing::Float64
│    %52 = (%48)(integrator, %49, %50, %51, event_idx)::Core.PartialStruct(Tuple{Bool, Bool}, Any[Bool, Core.Const(true)])

which, at the LLVM level, calls out to gc_alloc_obj and churns through a lot of allocations if you have many callbacks that fire frequently:

   %current_task10 = getelementptr inbounds {}*, {}** %95, i64 -13
   %96 = call noalias nonnull {}* @julia.gc_alloc_obj({}** %current_task10, i64 1592, {}* inttoptr (i64 3010981741392 to {}*)) #1
   %97 = bitcast {}* %96 -- snip --
   %98 = call {}* @ijl_get_nth_field_checked({}* %96, i64 %94)
  %99 = load {}*, {}** %integrator, align 8
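
The same effect can be reproduced outside of DiffEqBase with a minimal sketch. Everything here is a hypothetical stand-in and not part of any SciML package: the types `AddOne`, `Double`, and `Negate` play the role of differently-typed callbacks, and `apply` and `pick` are illustrative helpers. The sketch only asks inference what a runtime-indexed tuple access returns.

```julia
# Hypothetical stand-ins: each "callback" has its own concrete type,
# just as each ContinuousCallback is parameterized by its closure types.
struct AddOne end
struct Double end
struct Negate end

apply(::AddOne, x::Float64) = x + 1.0
apply(::Double, x::Float64) = 2.0 * x
apply(::Negate, x::Float64) = -x

cbs = (AddOne(), Double(), Negate())

# The index is only known at runtime, so inference cannot pick one element type.
pick(callbacks, idx) = callbacks[idx]

# Ask inference what `pick` returns for this tuple type and an Int index.
inferred = only(Base.return_types(pick, (typeof(cbs), Int)))
println("inferred element type: ", inferred)
println("concrete? ", isconcretetype(inferred))
```

For a tuple this small the compiler may still be able to union-split the follow-up call, so this toy does not necessarily allocate. With many callbacks that carry distinct closure types, as in the output above, the inferred type is wider and the call falls back to dynamic dispatch.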

This is fixed by adding a generated overload that generates code that explicitly iterates over the tuple type in a type stable way. It also often seems to inline specialized apply_callback! calls, which is nice as well:

@generated function apply_callback!(integrator, cb_time, prev_sign, event_idx,
    cb_idx, callbacks::NTuple{N, Union{ContinuousCallback, VectorContinuousCallback}}) where {N}
    # Fallback when cb_idx matches none of the unrolled indices 1:N.
    ex = :(throw(BoundsError(callbacks, cb_idx)))
    for i = 1:N
        # Wrap the previous expression so each index gets its own statically typed branch.
        ex = quote
            if (cb_idx == $i)
                return apply_callback!(integrator, callbacks[$i], cb_time, prev_sign, event_idx)
            end
            $ex
        end
    end
    ex
end
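
To see the unrolling pattern in isolation, here is a small self-contained sketch. It assumes hypothetical stand-in types (`AddOne`, `Double`, `Negate`) and a hypothetical `apply` function in place of the real callbacks and `apply_callback!`; the name `apply_ith` is likewise only illustrative.

```julia
# Hypothetical stand-ins for callbacks with distinct concrete types.
struct AddOne end
struct Double end
struct Negate end

apply(::AddOne, x::Float64) = x + 1.0
apply(::Double, x::Float64) = 2.0 * x
apply(::Negate, x::Float64) = -x

# Unroll the tuple at compile time: every branch indexes with a literal,
# so each `apply` call sees a concrete element type.
@generated function apply_ith(cbs::NTuple{N, Any}, idx::Int, x::Float64) where {N}
    # Innermost fallback for an index outside 1:N.
    ex = :(throw(BoundsError(cbs, idx)))
    for i in N:-1:1
        ex = quote
            if idx == $i
                return apply(cbs[$i], x)
            else
                $ex
            end
        end
    end
    return ex
end

cbs = (AddOne(), Double(), Negate())
println(apply_ith(cbs, 2, 3.0))   # dispatches to apply(::Double, 3.0)
println(only(Base.return_types(apply_ith, (typeof(cbs), Int, Float64))))
```

Because every branch returns the result of a statically dispatched call, the return type of the whole function can be inferred as long as each individual `apply` method is itself type-stable.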

Generating find_first_continuous_callback to side-step inference failure

Depending on the ODE problem and the number of callbacks, the recursive find_first_continuous_callback methods sometimes fail to infer. In an example Cthulhu run, even though everything should be inferable, the resulting code is type-unstable, which unfortunately costs N allocations per internal step for N continuous callbacks :(

  args::Tuple{ContinuousCallback{JumpProcesses.var"#167#169"{Int64}, JumpProcesses.var"#168#170"{Random.TaskLocalRNG, Int64, VariableRateJump{ --snip, this is a tuple with ~7 continuous callbacks}

Body::Any
│   @ c:\Users\Christopher\source\repos\Supercoiling.jl\dev\DiffEqBase\src\callbacks.jl:129 within `find_first_continuous_callback`
│   %5 = Core._apply_iterate(Base.iterate, %1, %2, %3, %4, args)::Any Call inference reached maximally imprecise information. Bailing on. _apply_iterate inference reached maximally imprecise information. Bailing on.:

The recursive definition can relatively straightforwardly be replaced with a generated function:

@generated function find_first_continuous_callback(integrator, callbacks::NTuple{N, AbstractContinuousCallback}) where {N}
    # Start from the first callback as the current best candidate.
    ex = quote
        tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator, callbacks[1], 1)
        identified_idx = 1
    end

    # Unroll the remaining callbacks, keeping the earliest event found so far.
    for i = 2:N
        ex = quote
            $ex
            tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, callbacks[$i], $i)
            if event_occurred2 && (tmin2 < tmin || !event_occurred)
                tmin = tmin2
                upcrossing = upcrossing2
                event_occurred = true
                event_idx = event_idx2
                identified_idx = $i
            end
        end
    end
    ex = quote
        $ex
        return tmin, upcrossing, event_occurred, event_idx, identified_idx, $N
    end
    ex
end
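
The accumulate-the-earliest-event logic can be checked on its own with a toy version. This is only a sketch under stated assumptions: `At`, `Never`, `event_time`, and `find_first` are hypothetical names, with `event_time` standing in for `find_callback_time` and returning just a time and an occurred flag.

```julia
# Hypothetical stand-ins: a callback that fires at a fixed time, and one that never fires.
struct At
    t::Float64
end
struct Never end

# Stand-in for find_callback_time: returns (event time, whether an event occurred).
event_time(cb::At, i::Int) = (cb.t, true)
event_time(::Never, i::Int) = (Inf, false)

@generated function find_first(cbs::NTuple{N, Any}) where {N}
    # Seed with the first callback (this sketch assumes N >= 1).
    ex = quote
        tmin, occurred = event_time(cbs[1], 1)
        best = 1
    end
    # Unrolled comparisons against callbacks 2..N, each indexed with a literal.
    for i in 2:N
        ex = quote
            $ex
            t2, occurred2 = event_time(cbs[$i], $i)
            if occurred2 && (t2 < tmin || !occurred)
                tmin = t2
                occurred = true
                best = $i
            end
        end
    end
    return quote
        $ex
        return tmin, occurred, best
    end
end

println(find_first((Never(), At(2.0), At(0.5))))
```

The `!occurred` clause matters when the first callback does not fire: its time is then a placeholder, so any later callback that does fire replaces it regardless of the time comparison.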

This also is type stable and also often inlines the specific find_callback_time calls.

Implementation questions

I currently define these functions in DiffEqBase, and then change the implementation of handle_callbacks! in OrdinaryDiffEq to use versions that don't splat the continuous_callbacks tuple:

        time, upcrossing, event_occurred, event_idx, idx, counter = DiffEqBase.find_first_continuous_callback(integrator,
            continuous_callbacks)
        if event_occurred
            integrator.event_last_time = idx
            integrator.vector_event_last_time = event_idx
            continuous_modified, saved_in_cb = DiffEqBase.apply_callback!(integrator,
                time, upcrossing,
                event_idx,
                idx,
                continuous_callbacks)
  • Is this the best way to implement this, e.g. with two separate PRs?
  • At a glance, there don't seem to be explicit tests for either of these functions in DiffEqBase. I can add some, or is this something that should be implicitly covered by the tests in OrdinaryDiffEq?
  • Use of the generated function means that every ODEProblem with different callbacks needs to compile this function, which might not be desired depending on the user specialization level. At the same time, I think that these functions are already going to be recompiled no matter what, so is this a performance concern?

Possible alternate implementation

Instead of messing with DiffEqBase, this could alternatively be implemented by turning handle_callbacks! into a generated function, or by adding a helper generated function that replaces just these lines, which could be a little nicer.

@ChrisRackauckas
Member

Hey, did this ever find its way to a PR? I know you put a few things in but I can't find a PR associated with this, though I swear I remembered an email about it? Or maybe I'm misremembering.

@meson800
Author

meson800 commented Aug 9, 2023

I got caught up a bit with the normal PhD grind. I can work on getting something PR-worthy within the next week or so. I was looking for some amount of suggestion on where to PR. My leading idea that I'll PR for unless you have more feedback is:

  1. PR in DiffEqBase.jl to replace the recursive find_first_continuous_callback with a generated function alternative, which should type-infer more robustly. This change should be "invisible" to callers. Then, PR in OrdinaryDiffEq.jl, pulling the logic in handle_callbacks! that calls find_first_continuous_callback and apply_callback! out into its own generated function that does this in a type-stable way. Two PRs is slightly annoying, but I think this is the cleanest.

Alternatively, I could:

  2. Put both of these in a single PR in OrdinaryDiffEq.jl, but this would unnecessarily duplicate find_first_continuous_callback across the packages.
  3. Add an apply_callback! dispatch that takes a tuple of callbacks and an index instead of a single callback object. With this dispatch, other code that looks like handle_callbacks! could benefit from type stability without having to add generated functions at its call site. I don't favor this because the current implementation is fine; it's the surrounding code in handle_callbacks! that causes the instability.

meson800 added a commit to meson800/DiffEqBase.jl that referenced this issue Aug 22, 2023
As mentioned in SciML/DifferentialEquations.jl#971, the current
recursive method for identifying the first continuous callback can cause
the compiler to give up on type inference, especially when there are
many callbacks. The fallback then allocates.

This switches this function to using a generated function (along with an
inline function that takes splatted tuples). Because this generated
function explicitly unrolls the tuple, there are no type inference
problems.

I added a test that allocates using the old implementation (about 19kb
allocations!) but does not with the new system.
meson800 added a commit to meson800/OrdinaryDiffEq.jl that referenced this issue Aug 23, 2023
Currently, as referenced in SciML/DifferentialEquations.jl#971, the old
implementation of `handle_callbacks!` directly calls `apply_callback!`
on `continuous_callbacks[idx]`, which is inherently type-unstable
because `apply_callback!` is specialized on the callback type.

This commit adds a generated function `apply_ith_callback!` which
generates type-stable code to do the same thing, where for each callback
tuple type, the generated function unrolls the tuple by checking the
callback index against static indices. As a nice bonus, this generated
function seems to often be converted into a switch statement at the LLVM
level:
```
   switch i64 %4, label %L46 [
    i64 9, label %L3
    i64 8, label %L8
    i64 7, label %L13
    i64 6, label %L18
    i64 5, label %L23
    i64 4, label %L28
    i64 3, label %L33
    i64 2, label %L38
    i64 1, label %L43
  ]
```

For testing, I added an allocation test which sets up a simple ODE
problem, steps the integrator manually to before the first callback,
then manipulates integrator state past the first callback point. This
way, we can directly call `handle_callbacks!` and write a test on the
allocation count. I confirm that (at least testing against commit
SciML/DiffEqBase.jl@1799fc3, the current master branch tip in
DiffEqBase.jl), the new method does not allocate, whereas the old one
allocates. This may not be the case until a new release is cut of
DiffEqBase.jl, because the old version of
`find_first_continuous_callback` might allocate.
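
An allocation test of this general shape can be sketched without any ODE machinery. The code below is not the test from the commit; `AddOne`, `Double`, `apply`, `apply_ith`, and `allocations_of_dispatch` are hypothetical names used only to illustrate warming up a call and then measuring it with `@allocated` from inside a function.

```julia
# Hypothetical stand-ins for callbacks with distinct concrete types.
struct AddOne end
struct Double end

apply(::AddOne, x::Float64) = x + 1.0
apply(::Double, x::Float64) = 2.0 * x

# Unrolled dispatch over the tuple, in the same style as the generated functions above.
@generated function apply_ith(cbs::NTuple{N, Any}, idx::Int, x::Float64) where {N}
    ex = :(throw(BoundsError(cbs, idx)))
    for i in N:-1:1
        ex = quote
            idx == $i ? (return apply(cbs[$i], x)) : $ex
        end
    end
    return ex
end

# Measure inside a function so that global-scope dispatch is not counted.
function allocations_of_dispatch(cbs, idx, x)
    apply_ith(cbs, idx, x)                     # warm up: compile before measuring
    return @allocated apply_ith(cbs, idx, x)   # bytes allocated by the second call
end

n = allocations_of_dispatch((AddOne(), Double()), 2, 3.0)
println("bytes allocated: ", n)
```

For a type-stable call like this one would hope to see zero bytes, but the exact number can depend on the Julia version, so a real test should state the threshold it expects explicitly.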
@meson800
Author

This is closable now that both PRs have been merged. Thanks Chris!

@ChrisRackauckas
Member

Thanks for your contributions! These are a few things that have bugged me for a while, so I'm glad someone put the work in to fix them.
