Add CPU callbacks for stream capture & Cythonize GraphBuilder #1814
Andy-Jost wants to merge 4 commits into NVIDIA:main
Move the GraphBuilder/Graph/GraphCompleteOptions/GraphDebugPrintOptions implementation out of _graph/__init__.py into _graph/_graph_builder.pyx so it is compiled by Cython. A thin __init__.py re-exports the public names so all existing import sites continue to work unchanged.

Cython compatibility adjustments:
- Remove `from __future__ import annotations` (unsupported by Cython)
- Remove TYPE_CHECKING guard; quote annotations that reference Stream (circular import), forward-reference GraphBuilder/Graph, or use X | None union syntax
- Update _graphdef.pyx lazy imports to point directly at _graph_builder

No build_hooks.py changes needed: the build system auto-discovers .pyx files via glob.

Ref: NVIDIA#1076
Made-with: Cursor
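The annotation-quoting adjustment described above can be illustrated in plain Python. This is a minimal sketch, not the code from the PR: the classes are simplified stand-ins, and the method signatures are illustrative rather than the real cuda.core ones. It shows that quoted annotations stay as strings, so forward references and `X | None` unions do not need to resolve when the function is defined.

```python
# Simplified stand-ins; the real classes live in _graph/_graph_builder.pyx.
class GraphBuilder:
    # "Graph" and "GraphCompleteOptions" are defined further down, so the
    # annotations are quoted. Quoting also keeps the `X | None` union from
    # being evaluated when the function is defined.
    def complete(self, options: "GraphCompleteOptions | None" = None) -> "Graph":
        return Graph()

    # A class may not refer to itself by name inside its own body unless
    # the annotation is quoted (or annotation evaluation is deferred).
    def begin_building(self) -> "GraphBuilder":
        return self


class GraphCompleteOptions:
    pass


class Graph:
    pass


builder = GraphBuilder()
graph = builder.begin_building().complete()
```

The quoted form does by hand what `from __future__ import annotations` would otherwise do for every annotation in the module.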
Replace the per-module _lazy_init / _inited / _driver_ver / _py_major_minor pattern in _graph_builder.pyx with direct calls to centralized cached functions in cuda_utils:
- Add get_driver_version() with @functools.cache alongside get_binding_version
- Switch get_binding_version from @functools.lru_cache to @functools.cache (cleaner for nullary functions)
- Fix split() to return tuple(result); Cython enforces return type annotations, unlike pure Python
- Fix _cond_with_params annotation from -> GraphBuilder to -> tuple to match actual return value

Made-with: Cursor
This file was moved to _graph/_graph_builder.pyx and replaced with a thin re-exporter
Moved from `_graph/__init__.py`. Changes to this file:
- Replaced explicit `_lazy_init` with direct calls to cached functions `get_binding_version` and `get_driver_version`
- Added `GraphBuilder.callback`
    def callback(self, fn, *, user_data=None):
        """Add a host callback to the graph during stream capture.

        The callback runs on the host CPU when the graph reaches this point
        in execution. Two modes are supported:

        - **Python callable**: Pass any callable. The GIL is acquired
          automatically. The callable must take no arguments; use closures
          or ``functools.partial`` to bind state.
        - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance.
          The function receives a single ``void*`` argument (the
          ``user_data``). The caller must keep the ctypes wrapper alive
          for the lifetime of the graph.

        .. warning::

            Callbacks must not call CUDA API functions. Doing so may
            deadlock or corrupt driver state.

        Parameters
        ----------
        fn : callable or ctypes function pointer
            The callback function.
        user_data : int or bytes-like, optional
            Only for ctypes function pointers. If ``int``, passed as a raw
            pointer (caller manages lifetime). If bytes-like, the data is
            copied and its lifetime is tied to the graph.
        """
        cdef Stream stream = <Stream>self._mnff.stream
        cdef cydriver.CUstream c_stream = as_cu(stream._h_stream)
        cdef cydriver.CUstreamCaptureStatus capture_status
        cdef cydriver.CUgraph c_graph = NULL

        with nogil:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL))

        if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
            raise RuntimeError("Cannot add callback when graph is not being built")

        cdef cydriver.CUhostFn c_fn
        cdef void* c_user_data = NULL
        _attach_host_callback_to_graph(c_graph, fn, user_data, &c_fn, &c_user_data)

        with nogil:
            HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data))
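The two callback shapes named in the docstring can be tried without a GPU. The sketch below does not touch cuda.core or the driver: it builds one callback of each kind and invokes them directly from Python, standing in for what the driver would do when the graph runs. The names `record`, `HOST_FN`, `payload`, and `c_callback` are hypothetical and exist only for this illustration.

```python
import ctypes
import functools

# Mode 1: a Python callable that takes no arguments.
# State is bound up front with functools.partial, as the docstring suggests.
results = []


def record(tag, sink):
    sink.append(tag)


py_callback = functools.partial(record, "step-1", results)

# Mode 2: a ctypes function pointer that receives a single void* argument.
HOST_FN = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
payload = ctypes.c_int(41)  # must stay alive while the callback can still run


@HOST_FN
def c_callback(user_data):
    # ctypes delivers a void* argument to Python as an integer address.
    value = ctypes.c_int.from_address(user_data)
    value.value += 1


# Simulate the invocation that would normally happen during graph execution.
py_callback()
c_callback(ctypes.addressof(payload))
```

Going by the docstring, the corresponding calls during capture would look like `builder.callback(py_callback)` and `builder.callback(c_callback, user_data=ctypes.addressof(payload))`, where `builder` is a `GraphBuilder` that is actively capturing. With an integer `user_data` the caller manages the lifetime, so both `payload` and `c_callback` need to outlive the graph.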
Implements NVIDIA#1328: host callbacks during stream capture via cuLaunchHostFunc, mirroring the existing GraphDef.callback API.

Extracts shared callback infrastructure (_attach_user_object, _attach_host_callback_to_graph, trampoline/destructor) into a new _graph/_utils.pyx module to avoid circular imports between _graph_builder and _graphdef.

Made-with: Cursor
39e5c57 to edbc361
cpcloud left a comment
GraphBuilder.callback() in cuda_core/cuda/core/_graph/_graph_builder.pyx now calls cydriver.cuStreamGetCaptureInfo(...) directly. That Cython symbol is only generated when the bindings headers expose cuStreamGetCaptureInfo_v3 (cuda_bindings/cuda/bindings/cydriver.pxd.in), but CI still rebuilds cuda_core against the previous supported CUDA major (12.9.1).
That matches the PR's current Linux build failures in the second Build cuda.core wheel phase. Please switch this path back to the existing Python wrapper (driver.cuStreamGetCaptureInfo(...)) or otherwise gate/fallback the direct C call so the callback implementation still builds against the CUDA 12 compatibility configuration.
Summary
- Add `GraphBuilder.callback()`, mirroring the existing `GraphDef.callback()` API
- Cythonize `_graph/_graph_builder.pyx` (converts from pure Python to Cython `.pyx`)
- Extract shared callback infrastructure into `_graph/_utils.pyx` to avoid circular imports

Changes
- `_graph/_utils.pyx` / `_graph/_utils.pxd` (new): Shared callback infrastructure (`_attach_user_object`, `_attach_host_callback_to_graph`, `_py_host_trampoline` / `_py_host_destructor`, and the `_is_py_host_trampoline` helper)
- `_graph/_graph_builder.pyx`: Converted from `.py` to `.pyx`; added `callback()` method using `cuLaunchHostFunc`; centralized version caching via `get_driver_version()` / `get_binding_version()` (removed per-module `_lazy_init`)
- `_graph/_graphdef.pyx`: Refactored `GraphNode.callback()` to use the shared `_attach_host_callback_to_graph` helper; removed duplicated callback infrastructure

Test Coverage
- `tests/graph/test_basic.py`: Python callable callback, ctypes CFuncPtr callback with `user_data`, and `user_data` rejection for Python callables

Related Work
Made with Cursor