cuda.compute: Stateful operators are slow because they almost always need recompilation #7498

@shwina

Description

In #7008, we added support for stateful ops in cuda.compute. However, stateful ops have turned out to be slow in practice (see, for example, scikit-hep/awkward#3814 (comment)).

This is because, in practice, stateful ops are used something like this:

def library_func(x, y, array):

    def stateful_op(x):
        return x + array[0]  # references the enclosing `array`

    cuda.compute.transform(..., stateful_op, ...)

Above, we hope that the first call to library_func triggers compilation of stateful_op, while successive calls reuse the compiled operator.

In reality, stateful_op is compiled on every call. Each call to library_func creates a new stateful_op that closes over a different array object. numba-cuda captures the pointer of a referenced array as a compile-time constant before compilation; when the referenced array changes, the pointer is a different constant, and the operator is recompiled.
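
Schematically, the failure mode looks like a compile cache whose key includes the captured pointers (a pure-Python sketch of the caching behavior, not numba-cuda's actual internals):

# Hypothetical sketch: a compile cache keyed on the op's bytecode plus
# the pointers it captures. Every new array object changes the key,
# so identical bytecode still misses the cache.
compile_cache = {}

def get_compiled(op, captured_ptrs):
    key = (op.__code__, tuple(captured_ptrs))
    if key not in compile_cache:
        compile_cache[key] = object()  # stand-in for the compiled artifact
    return compile_cache[key]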

The solution

The solution is to use CCCL.c's native support for stateful operations. There, state is passed to the library via an additional void* state parameter, so that stateful operators have the signature:

void (void* state, void* arg1, ..., void* argN, void* result_ptr)

So, on the Python side it's our job to take a stateful operator like:

# example: stateful unary operator

def foo(x):
    return x + y[0]  # `y` is a global/closure

And translate it to something like:

void type_erased_f(void* state, void* x, void* result_ptr);

(for ODR reasons, all arguments must be declared void*).

To achieve this, we first perform some Python-side AST manipulation, transforming a function that references external arrays into one that takes them as explicit arguments:

# given:
def foo(x):
    return x + y[0]

# transform it to:
def foo(x, y):
    return x + y[0]
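
A minimal sketch of this lifting step, assuming the set of names to lift is already known (lift_closures is a hypothetical helper, not the actual cuda.compute implementation):

import ast
import inspect
import textwrap

def lift_closures(func, names):
    # Rewrite `func` so that each global/closure name in `names` becomes
    # an explicit trailing parameter; the parameter then shadows the
    # original binding inside the function body.
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    fn = tree.body[0]
    for name in names:
        fn.args.args.append(ast.arg(arg=name))
    ast.fix_missing_locations(tree)
    namespace = {}
    exec(compile(tree, filename="<lifted>", mode="exec"), namespace)
    return namespace[fn.name]

y = [10]

def foo(x):
    return x + y[0]

foo2 = lift_closures(foo, ["y"])
assert foo2(1, y) == 11  # `y` is now an explicit argument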

We then use numba to define and compile a wrapper function with the correct signature, void(void* state, void* x, void* result_ptr). Within the wrapper, we interpret the state pointer as a packed array of pointers to the externally referenced arrays, unpack them into array objects, and pass those to the transformed function above.
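
To make the state layout concrete, here is a pure-Python/ctypes sketch of the packing and unpacking (pack_state/unpack_state are illustrative, and .data.ptr stands in for however a given device array exposes its pointer; the real unpacking happens inside the compiled wrapper):

import ctypes

def pack_state(arrays):
    # The state blob is a packed array of device pointers, one per
    # externally referenced array, in a fixed order.
    return (ctypes.c_void_p * len(arrays))(
        *(arr.data.ptr for arr in arrays)  # e.g. a CuPy-style pointer attribute
    )

def unpack_state(state_ptr, n):
    # Reinterpret the void* as n consecutive pointers; the compiled
    # wrapper does the equivalent before rebuilding array views.
    ptrs_t = ctypes.POINTER(ctypes.c_void_p * n)
    return list(ctypes.cast(state_ptr, ptrs_t).contents)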

State updates

On every invocation, the state (which holds pointers to the externally referenced arrays) must be updated to hold the correct pointers. This introduces an API problem for cuda.compute. Consider the (object-based) API of make_unary_transform:

# step 1: create algorithm object
transformer = cuda.compute.make_unary_transform(d_input, d_output, some_unary_op)

# step 2: invoke algorithm (can be done several times, with differing inputs)
transformer(d_in1, d_out1, num_items1)
transformer(d_in2, d_out2, num_items2)

The problem is that we assume that some_unary_op has no state that can change between construction time and invocation(s). This is no longer the case, as some_unary_op's state can change by changing the arrays it references (as would happen in the motivating example presented in this issue description).

To account for this, we must have the user pass the op at invocation time as well. This can be the same callable with its globals/closures changed, or an entirely different callable with the same bytecode but different globals/closures. This would make the API:

# step 1: create algorithm object
transformer = cuda.compute.make_unary_transform(d_input, d_output, some_unary_op)

# step 2: invoke algorithm (can be done several times, with differing inputs)
transformer(d_in1, d_out1, some_unary_op, num_items1)  # note, passing some_unary_op to __call__
transformer(d_in2, d_out2, some_unary_op, num_items2)

In the __call__ method, we extract the state from the passed callable.

Performance

This extraction of the state from the user-provided callable is not free, especially since we first have to determine whether there is any state to extract at all (which involves inspecting the function object's globals/closures for device arrays). It's not a huge cost, but it's not negligible either (~1 µs).
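
For illustration, the detection step might look something like this (a hypothetical sketch that uses the __cuda_array_interface__ protocol to recognize device arrays; extract_referenced_arrays is not the actual cuda.compute helper):

def extract_referenced_arrays(op):
    # Collect device arrays that `op` references through closure cells
    # or its global namespace.
    found = []
    if op.__closure__:
        for name, cell in zip(op.__code__.co_freevars, op.__closure__):
            value = cell.cell_contents
            if hasattr(value, "__cuda_array_interface__"):
                found.append((name, value))
    for name in op.__code__.co_names:  # names the bytecode may read as globals
        value = op.__globals__.get(name)
        if hasattr(value, "__cuda_array_interface__"):
            found.append((name, value))
    return found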

As a follow-up, we should find a way to minimize these costs, or at the very least provide a path to skip the logic when the user knows they have a stateless op.
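
One possible shape for such an escape hatch (the assume_stateless flag below is purely hypothetical, not an existing cuda.compute parameter):

# Hypothetical opt-out: the caller asserts the op captures no device
# arrays, so state extraction and updates are skipped entirely.
transformer(d_in1, d_out1, some_unary_op, num_items1, assume_stateless=True)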

API breakage

As noted above, this would break the low-level, object-based API for most algorithms that accept ops. The high-level "single-phase" APIs remain unchanged.
