Description
Is your feature request related to a problem? Please describe.
Currently, after a kernel has been compiled, Numba directly creates the module containing the kernel entry point and launches it. Third-party libraries sometimes need to initialize the module with a host function before their device APIs can be used. Numba should provide a way to insert module initialization callbacks into its kernel launch procedure.
Describe the solution you'd like
Just an early proposal, please feel free to suggest alternatives.
[EDIT: 022825]
Offline discussion concluded that the init callback function can be tied to the linkable code object. A third-party library must provide its implementation through the linkable code interface, so the initialization steps the module needs can be attached to that same object.

    class LinkableCode:
        data: str
        name: str
        init_callbacks: callable | None

[Original Post]
We can try adding a new argument to jit call:

    @cuda.jit(mod_init_callback=[callbacks,])
    def kernel():
        ...

The callback function should follow this signature:

    def callback(mod: ctypes.c_void_p, stream: int):
        ...

Alternatively, the signature can use the cuda-python object model:

    def callback(mod: cuda.bindings.driver.CUmodule, stream: cuda.bindings.runtime.cudaStream_t):
        ...

Numba then invokes the callbacks, in the specified order, on the compiled kernel's module prior to launch.
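To make the intended ordering concrete, here is a minimal sketch that runs without a GPU. It assumes the `callback(mod, stream)` signature proposed above; `launch_with_init`, `fake_mod`, and the lambda callbacks are hypothetical stand-ins, not part of Numba.

```python
import ctypes
from typing import Callable, Sequence

# Signature proposed in this issue: callback(mod: ctypes.c_void_p, stream: int)
InitCallback = Callable[[ctypes.c_void_p, int], None]


def launch_with_init(
    mod: ctypes.c_void_p,
    stream: int,
    callbacks: Sequence[InitCallback],
    launch: Callable[[], None],
) -> None:
    """Hypothetical helper: run each callback in the given order, then launch."""
    for cb in callbacks:
        cb(mod, stream)
    launch()


# Stand-ins so the ordering can be observed without a real CUDA module.
events = []
fake_mod = ctypes.c_void_p(0x1234)  # placeholder value, not a real module handle

launch_with_init(
    fake_mod,
    0,
    [
        lambda mod, stream: events.append(("init_a", mod.value, stream)),
        lambda mod, stream: events.append(("init_b", mod.value, stream)),
    ],
    launch=lambda: events.append(("launch",)),
)
print(events)
```

Both callbacks see the same module handle and stream, and the launch happens only after the last callback returns.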
Describe alternatives you've considered
Alternatives are basically different ways to pass in the callback function. Since the compiled module is an implementation detail inside Numba, we can't / shouldn't work on it directly.
Alt1 (pass in via the launch configuration):

    kernel[1, 1, stream, smem, modinit]()

Alt2 (pass in via the kernel argument):

    class ModInitCallback:
        ...

    kernel[...](ModInitCallback(callback))
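The two alternatives can be compared with a small mock. This is only a sketch: `MockKernel` is a hypothetical stand-in rather than Numba's dispatcher, the `fn` field on `ModInitCallback` is an assumption, and treating `modinit` as a list of callbacks is also an assumption.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class ModInitCallback:
    """Alt2 marker type: wraps a callback passed among the kernel arguments."""
    fn: Callable[[Any, int], None]


class MockKernel:
    """Hypothetical stand-in for a compiled kernel; not Numba's dispatcher."""

    def __init__(self, module: Any = "fake-module"):
        self.module = module
        self.log = []
        self._stream = 0
        self._config_callbacks = []

    def __getitem__(self, config):
        # Alt1: an optional fifth launch-configuration entry carries callbacks.
        grid, block, stream, smem, *rest = config
        self._stream = stream
        self._config_callbacks = list(rest[0]) if rest else []
        return self  # simplification: the mock configures itself in place

    def __call__(self, *args):
        # Alt2: separate ModInitCallback wrappers from the real kernel arguments.
        arg_callbacks = [a.fn for a in args if isinstance(a, ModInitCallback)]
        kernel_args = [a for a in args if not isinstance(a, ModInitCallback)]
        for cb in self._config_callbacks + arg_callbacks:
            cb(self.module, self._stream)
        self.log.append(("launch", tuple(kernel_args)))


kernel = MockKernel()


def record(tag):
    # Build a callback that logs its tag and the stream it was given.
    return lambda mod, stream: kernel.log.append((tag, stream))


kernel[1, 1, 7, 0, [record("alt1")]](10, 20)             # Alt1
kernel[1, 1, 7, 0](ModInitCallback(record("alt2")), 30)  # Alt2
print(kernel.log)
```

In both cases the callback runs before the launch is recorded. Alt1 keeps the kernel argument list untouched, while Alt2 requires the launcher to filter the wrapper objects out of the arguments.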