From 631a74ccb6b142b1e59795750aaf0244af6b5211 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 24 Mar 2026 10:12:01 -0700 Subject: [PATCH 1/8] Cythonize _graph/_graph_builder (move from pure Python to .pyx) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move the GraphBuilder/Graph/GraphCompleteOptions/GraphDebugPrintOptions implementation out of _graph/__init__.py into _graph/_graph_builder.pyx so it is compiled by Cython. A thin __init__.py re-exports the public names so all existing import sites continue to work unchanged. Cython compatibility adjustments: - Remove `from __future__ import annotations` (unsupported by Cython) - Remove TYPE_CHECKING guard; quote annotations that reference Stream (circular import), forward-reference GraphBuilder/Graph, or use X | None union syntax - Update _graphdef.pyx lazy imports to point directly at _graph_builder No build_hooks.py changes needed — the build system auto-discovers .pyx files via glob. Ref: https://github.com/NVIDIA/cuda-python/issues/1076 Made-with: Cursor --- cuda_core/cuda/core/_graph/__init__.py | 805 +----------------- cuda_core/cuda/core/_graph/_graph_builder.pyx | 791 +++++++++++++++++ cuda_core/cuda/core/_graph/_graphdef.pyx | 4 +- 3 files changed, 807 insertions(+), 793 deletions(-) create mode 100644 cuda_core/cuda/core/_graph/_graph_builder.pyx diff --git a/cuda_core/cuda/core/_graph/__init__.py b/cuda_core/cuda/core/_graph/__init__.py index 2f1179312b..635ddfdf37 100644 --- a/cuda_core/cuda/core/_graph/__init__.py +++ b/cuda_core/cuda/core/_graph/__init__.py @@ -1,796 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -import weakref -from dataclasses import dataclass -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from cuda.core._stream import Stream - -from cuda.core._utils.cuda_utils import ( - driver, - get_binding_version, - handle_return, +from cuda.core._graph._graph_builder import ( + Graph, + GraphBuilder, + GraphCompleteOptions, + GraphDebugPrintOptions, + _instantiate_graph, ) -_inited = False -_driver_ver = None - - -def _lazy_init(): - global _inited - if _inited: - return - - global _py_major_minor, _driver_ver - # binding availability depends on cuda-python version - _py_major_minor = get_binding_version() - _driver_ver = handle_return(driver.cuDriverGetVersion()) - _inited = True - - -@dataclass -class GraphDebugPrintOptions: - """Customizable options for :obj:`_graph.GraphBuilder.debug_dot_print()` - - Attributes - ---------- - verbose : bool - Output all debug data as if every debug flag is enabled (Default to False) - runtime_types : bool - Use CUDA Runtime structures for output (Default to False) - kernel_node_params : bool - Adds kernel parameter values to output (Default to False) - memcpy_node_params : bool - Adds memcpy parameter values to output (Default to False) - memset_node_params : bool - Adds memset parameter values to output (Default to False) - host_node_params : bool - Adds host parameter values to output (Default to False) - event_node_params : bool - Adds event parameter values to output (Default to False) - ext_semas_signal_node_params : bool - Adds external semaphore signal parameter values to output (Default to False) - ext_semas_wait_node_params : bool - Adds external semaphore wait parameter values to output (Default to False) - kernel_node_attributes : bool - Adds kernel node attributes to output (Default to False) - handles : bool - Adds node handles and every kernel function handle to output (Default to False) - mem_alloc_node_params : bool - Adds 
memory alloc parameter values to output (Default to False) - mem_free_node_params : bool - Adds memory free parameter values to output (Default to False) - batch_mem_op_node_params : bool - Adds batch mem op parameter values to output (Default to False) - extra_topo_info : bool - Adds edge numbering information (Default to False) - conditional_node_params : bool - Adds conditional node parameter values to output (Default to False) - - """ - - verbose: bool = False - runtime_types: bool = False - kernel_node_params: bool = False - memcpy_node_params: bool = False - memset_node_params: bool = False - host_node_params: bool = False - event_node_params: bool = False - ext_semas_signal_node_params: bool = False - ext_semas_wait_node_params: bool = False - kernel_node_attributes: bool = False - handles: bool = False - mem_alloc_node_params: bool = False - mem_free_node_params: bool = False - batch_mem_op_node_params: bool = False - extra_topo_info: bool = False - conditional_node_params: bool = False - - def _to_flags(self) -> int: - """Convert options to CUDA driver API flags (internal use).""" - flags = 0 - if self.verbose: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE - if self.runtime_types: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES - if self.kernel_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS - if self.memcpy_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS - if self.memset_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS - if self.host_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS - if self.event_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS - if self.ext_semas_signal_node_params: - flags |= 
driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS - if self.ext_semas_wait_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS - if self.kernel_node_attributes: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES - if self.handles: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES - if self.mem_alloc_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS - if self.mem_free_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS - if self.batch_mem_op_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS - if self.extra_topo_info: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO - if self.conditional_node_params: - flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS - return flags - - -@dataclass -class GraphCompleteOptions: - """Customizable options for :obj:`_graph.GraphBuilder.complete()` - - Attributes - ---------- - auto_free_on_launch : bool, optional - Automatically free memory allocated in a graph before relaunching. (Default to False) - upload_stream : Stream, optional - Stream to use to automatically upload the graph after completion. (Default to None) - device_launch : bool, optional - Configure the graph to be launchable from the device. This flag can only - be used on platforms which support unified addressing. This flag cannot be - used in conjunction with auto_free_on_launch. (Default to False) - use_node_priority : bool, optional - Run the graph using the per-node priority attributes rather than the - priority of the stream it is launched into. 
(Default to False) - - """ - - auto_free_on_launch: bool = False - upload_stream: Stream | None = None - device_launch: bool = False - use_node_priority: bool = False - - -def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph: - params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() - if options: - flags = 0 - if options.auto_free_on_launch: - flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH - if options.upload_stream: - flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD - params.hUploadStream = options.upload_stream.handle - if options.device_launch: - flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH - if options.use_node_priority: - flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY - params.flags = flags - - graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))) - if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: - raise RuntimeError( - "Instantiation failed for an unexpected reason which is described in the return value of the function." - ) - elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE: - raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.") - elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED: - raise RuntimeError( - "Instantiation for device launch failed because the graph contained an unsupported operation." 
- ) - elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED: - raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.") - elif ( - _py_major_minor >= (12, 8) - and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED - ): - raise RuntimeError("One or more conditional handles are not associated with conditional builders.") - elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS: - raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}") - return graph - - -class GraphBuilder: - """Represents a graph under construction. - - A graph groups a set of CUDA kernels and other CUDA operations together and executes - them with a specified dependency tree. It speeds up the workflow by combining the - driver activities associated with CUDA kernel launches and CUDA API calls. - - Directly creating a :obj:`~_graph.GraphBuilder` is not supported due - to ambiguity. New graph builders should instead be created through a - :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 
- - """ - - class _MembersNeededForFinalize: - __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream") - - def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required): - self.stream = stream_obj - self.is_stream_owner = is_stream_owner - self.graph = None - self.conditional_graph = conditional_graph - self.is_join_required = is_join_required - weakref.finalize(graph_builder_obj, self.close) - - def close(self): - if self.stream: - if not self.is_join_required: - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0] - if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: - # Note how this condition only occures for the primary graph builder - # This is because calling cuStreamEndCapture streams that were split off of the primary - # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED. - # Therefore, it is currently a requirement that users join all split graph builders - # before a graph builder can be clearly destroyed. - handle_return(driver.cuStreamEndCapture(self.stream.handle)) - if self.is_stream_owner: - self.stream.close() - self.stream = None - if self.graph: - handle_return(driver.cuGraphDestroy(self.graph)) - self.graph = None - self.conditional_graph = None - - __slots__ = ("__weakref__", "_building_ended", "_mnff") - - def __init__(self): - raise NotImplementedError( - "directly creating a Graph object can be ambiguous. 
Please either " - "call Device.create_graph_builder() or stream.create_graph_builder()" - ) - - @classmethod - def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): - self = cls.__new__(cls) - _lazy_init() - self._mnff = GraphBuilder._MembersNeededForFinalize( - self, stream, is_stream_owner, conditional_graph, is_join_required - ) - - self._building_ended = False - return self - - @property - def stream(self) -> Stream: - """Returns the stream associated with the graph builder.""" - return self._mnff.stream - - @property - def is_join_required(self) -> bool: - """Returns True if this graph builder must be joined before building is ended.""" - return self._mnff.is_join_required - - def begin_building(self, mode="relaxed") -> GraphBuilder: - """Begins the building process. - - Build `mode` for controlling interaction with other API calls must be one of the following: - - - `global` : Prohibit potentially unsafe operations across all streams in the process. - - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread. - - `relaxed` : The local thread is not prohibited from potentially unsafe operations. - - Parameters - ---------- - mode : str, optional - Build mode to control the interaction with other API calls that are porentially unsafe. - Default set to use relaxed. 
- - """ - if self._building_ended: - raise RuntimeError("Cannot resume building after building has ended.") - if mode not in ("global", "thread_local", "relaxed"): - raise ValueError(f"Unsupported build mode: {mode}") - if mode == "global": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL - elif mode == "thread_local": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL - elif mode == "relaxed": - capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED - else: - raise ValueError(f"Unsupported build mode: {mode}") - - if self._mnff.conditional_graph: - handle_return( - driver.cuStreamBeginCaptureToGraph( - self._mnff.stream.handle, - self._mnff.conditional_graph, - None, # dependencies - None, # dependencyData - 0, # numDependencies - capture_mode, - ) - ) - else: - handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) - return self - - @property - def is_building(self) -> bool: - """Returns True if the graph builder is currently building.""" - capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] - if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: - return False - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: - return True - elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: - raise RuntimeError( - "Build process encountered an error and has been invalidated. Build process must now be ended." 
- ) - else: - raise NotImplementedError(f"Unsupported capture status type received: {capture_status}") - - def end_building(self) -> GraphBuilder: - """Ends the building process.""" - if not self.is_building: - raise RuntimeError("Graph builder is not building.") - if self._mnff.conditional_graph: - self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) - else: - self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle)) - - # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to - # resume the build process after the first call to end_building() - self._building_ended = True - return self - - def complete(self, options: GraphCompleteOptions | None = None) -> Graph: - """Completes the graph builder and returns the built :obj:`~_graph.Graph` object. - - Parameters - ---------- - options : :obj:`~_graph.GraphCompleteOptions`, optional - Customizable dataclass for the graph builder completion options. - - Returns - ------- - graph : :obj:`~_graph.Graph` - The newly built graph. - - """ - if not self._building_ended: - raise RuntimeError("Graph has not finished building.") - - return _instantiate_graph(self._mnff.graph, options) - - def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None): - """Generates a DOT debug file for the graph builder. - - Parameters - ---------- - path : str - File path to use for writting debug DOT output - options : :obj:`~_graph.GraphDebugPrintOptions`, optional - Customizable dataclass for the debug print options. - - """ - if not self._building_ended: - raise RuntimeError("Graph has not finished building.") - flags = options._to_flags() if options else 0 - handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags)) - - def split(self, count: int) -> tuple[GraphBuilder, ...]: - """Splits the original graph builder into multiple graph builders. - - The new builders inherit work dependencies from the original builder. 
- The original builder is reused for the split and is returned first in the tuple. - - Parameters - ---------- - count : int - The number of graph builders to split the graph builder into. - - Returns - ------- - graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...] - A tuple of split graph builders. The first graph builder in the tuple - is always the original graph builder. - - """ - if count < 2: - raise ValueError(f"Invalid split count: expecting >= 2, got {count}") - - event = self._mnff.stream.record() - result = [self] - for i in range(count - 1): - stream = self._mnff.stream.device.create_stream() - stream.wait(event) - result.append( - GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) - ) - event.close() - return result - - @staticmethod - def join(*graph_builders) -> GraphBuilder: - """Joins multiple graph builders into a single graph builder. - - The returned builder inherits work dependencies from the provided builders. - - Parameters - ---------- - *graph_builders : :obj:`~_graph.GraphBuilder` - The graph builders to join. - - Returns - ------- - graph_builder : :obj:`~_graph.GraphBuilder` - The newly joined graph builder. 
- - """ - if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): - raise TypeError("All arguments must be GraphBuilder instances") - if len(graph_builders) < 2: - raise ValueError("Must join with at least two graph builders") - - # Discover the root builder others should join - root_idx = 0 - for i, builder in enumerate(graph_builders): - if not builder.is_join_required: - root_idx = i - break - - # Join all onto the root builder - root_bdr = graph_builders[root_idx] - for idx, builder in enumerate(graph_builders): - if idx == root_idx: - continue - root_bdr.stream.wait(builder.stream) - builder.close() - - return root_bdr - - def __cuda_stream__(self) -> tuple[int, int]: - """Return an instance of a __cuda_stream__ protocol.""" - return self.stream.__cuda_stream__() - - def _get_conditional_context(self) -> driver.CUcontext: - return self._mnff.stream.context.handle - - def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle: - """Creates a conditional handle for the graph builder. - - Parameters - ---------- - default_value : int, optional - The default value to assign to the conditional handle. - - Returns - ------- - handle : driver.CUgraphConditionalHandle - The newly created conditional handle. 
- - """ - if _driver_ver < 12030: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles") - if _py_major_minor < (12, 3): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles") - if default_value is not None: - flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT - else: - default_value = 0 - flags = 0 - - status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) - if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: - raise RuntimeError("Cannot create a conditional handle when graph is not being built") - - return handle_return( - driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags) - ) - - def _cond_with_params(self, node_params) -> GraphBuilder: - # Get current capture info to ensure we're in a valid state - status, _, graph, *deps_info, num_dependencies = handle_return( - driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) - ) - if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: - raise RuntimeError("Cannot add conditional node when not actively capturing") - - # Add the conditional node to the graph - deps_info_update = [ - [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] - ] + [None] * (len(deps_info) - 1) - - # Update the stream's capture dependencies - handle_return( - driver.cuStreamUpdateCaptureDependencies( - self._mnff.stream.handle, - *deps_info_update, # dependencies, edgeData - 1, # numDependencies - driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, - ) - ) - - # Create new graph builders for each condition - return tuple( - [ - GraphBuilder._init( - stream=self._mnff.stream.device.create_stream(), - is_stream_owner=True, - conditional_graph=node_params.conditional.phGraph_out[i], - is_join_required=False, - ) - for i in range(node_params.conditional.size) - ] - ) - - def 
if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder: - """Adds an if condition branch and returns a new graph builder for it. - - The resulting if graph will only execute the branch if the conditional - handle evaluates to true at runtime. - - The new builder inherits work dependencies from the original builder. - - Parameters - ---------- - handle : driver.CUgraphConditionalHandle - The handle to use for the if conditional. - - Returns - ------- - graph_builder : :obj:`~_graph.GraphBuilder` - The newly created conditional graph builder. - - """ - if _driver_ver < 12030: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if") - if _py_major_minor < (12, 3): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if") - node_params = driver.CUgraphNodeParams() - node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL - node_params.conditional.handle = handle - node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF - node_params.conditional.size = 1 - node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] - - def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]: - """Adds an if-else condition branch and returns new graph builders for both branches. - - The resulting if graph will execute the branch if the conditional handle - evaluates to true at runtime, otherwise the else branch will execute. - - The new builders inherit work dependencies from the original builder. - - Parameters - ---------- - handle : driver.CUgraphConditionalHandle - The handle to use for the if-else conditional. - - Returns - ------- - graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`] - A tuple of two new graph builders, one for the if branch and one for the else branch. 
- - """ - if _driver_ver < 12080: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else") - if _py_major_minor < (12, 8): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else") - node_params = driver.CUgraphNodeParams() - node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL - node_params.conditional.handle = handle - node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF - node_params.conditional.size = 2 - node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) - - def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]: - """Adds a switch condition branch and returns new graph builders for all cases. - - The resulting switch graph will execute the branch that matches the - case index of the conditional handle at runtime. If no match is found, no branch - will be executed. - - The new builders inherit work dependencies from the original builder. - - Parameters - ---------- - handle : driver.CUgraphConditionalHandle - The handle to use for the switch conditional. - count : int - The number of cases to add to the switch conditional. - - Returns - ------- - graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...] - A tuple of new graph builders, one for each branch. 
- - """ - if _driver_ver < 12080: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch") - if _py_major_minor < (12, 8): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch") - node_params = driver.CUgraphNodeParams() - node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL - node_params.conditional.handle = handle - node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH - node_params.conditional.size = count - node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params) - - def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder: - """Adds a while loop and returns a new graph builder for it. - - The resulting while loop graph will execute the branch repeatedly at runtime - until the conditional handle evaluates to false. - - The new builder inherits work dependencies from the original builder. - - Parameters - ---------- - handle : driver.CUgraphConditionalHandle - The handle to use for the while loop. - - Returns - ------- - graph_builder : :obj:`~_graph.GraphBuilder` - The newly created while loop graph builder. - - """ - if _driver_ver < 12030: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop") - if _py_major_minor < (12, 3): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop") - node_params = driver.CUgraphNodeParams() - node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL - node_params.conditional.handle = handle - node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE - node_params.conditional.size = 1 - node_params.conditional.ctx = self._get_conditional_context() - return self._cond_with_params(node_params)[0] - - def close(self): - """Destroy the graph builder. - - Closes the associated stream if we own it. 
Borrowed stream - object will instead have their references released. - - """ - self._mnff.close() - - def add_child(self, child_graph: GraphBuilder): - """Adds the child :obj:`~_graph.GraphBuilder` builder into self. - - The child graph builder will be added as a child node to the parent graph builder. - - Parameters - ---------- - child_graph : :obj:`~_graph.GraphBuilder` - The child graph builder. Must have finished building. - """ - if (_driver_ver < 12000) or (_py_major_minor < (12, 0)): - raise NotImplementedError( - f"Launching child graphs is not implemented for versions older than CUDA 12." - f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}" - ) - - if not child_graph._building_ended: - raise ValueError("Child graph has not finished building.") - - if not self.is_building: - raise ValueError("Parent graph is not being built.") - - stream_handle = self._mnff.stream.handle - _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( - driver.cuStreamGetCaptureInfo(stream_handle) - ) - - # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159 - # for rationale - deps_info_trimmed = deps_info_out[:num_dependencies_out] - deps_info_update = [ - [ - handle_return( - driver.cuGraphAddChildGraphNode( - graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph - ) - ) - ] - ] + [None] * (len(deps_info_out) - 1) - handle_return( - driver.cuStreamUpdateCaptureDependencies( - stream_handle, - *deps_info_update, # dependencies, edgeData - 1, - driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, - ) - ) - - -class Graph: - """Represents an executable graph. - - A graph groups a set of CUDA kernels and other CUDA operations together and executes - them with a specified dependency tree. It speeds up the workflow by combining the - driver activities associated with CUDA kernel launches and CUDA API calls. 
- - Graphs must be built using a :obj:`~_graph.GraphBuilder` object. - - """ - - class _MembersNeededForFinalize: - __slots__ = "graph" - - def __init__(self, graph_obj, graph): - self.graph = graph - weakref.finalize(graph_obj, self.close) - - def close(self): - if self.graph: - handle_return(driver.cuGraphExecDestroy(self.graph)) - self.graph = None - - __slots__ = ("__weakref__", "_mnff") - - def __init__(self): - raise RuntimeError("directly constructing a Graph instance is not supported") - - @classmethod - def _init(cls, graph): - self = cls.__new__(cls) - self._mnff = Graph._MembersNeededForFinalize(self, graph) - return self - - def close(self): - """Destroy the graph.""" - self._mnff.close() - - @property - def handle(self) -> driver.CUgraphExec: - """Return the underlying ``CUgraphExec`` object. - - .. caution:: - - This handle is a Python object. To get the memory address of the underlying C - handle, call ``int()`` on the returned object. - - """ - return self._mnff.graph - - def update(self, builder: GraphBuilder): - """Update the graph using new build configuration from the builder. - - The topology of the provided builder must be identical to this graph. - - Parameters - ---------- - builder : :obj:`~_graph.GraphBuilder` - The builder to update the graph with. - - """ - if not builder._building_ended: - raise ValueError("Graph has not finished building.") - - # Update the graph with the new nodes from the builder - exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph)) - if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS: - raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}") - - def upload(self, stream: Stream): - """Uploads the graph in a stream. 
- - Parameters - ---------- - stream : :obj:`~_stream.Stream` - The stream in which to upload the graph - - """ - handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle)) - - def launch(self, stream: Stream): - """Launches the graph in a stream. - - Parameters - ---------- - stream : :obj:`~_stream.Stream` - The stream in which to launch the graph - - """ - handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle)) +__all__ = [ + "Graph", + "GraphBuilder", + "GraphCompleteOptions", + "GraphDebugPrintOptions", + "_instantiate_graph", +] diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx new file mode 100644 index 0000000000..a0141be23c --- /dev/null +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -0,0 +1,791 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +import weakref +from dataclasses import dataclass + +from cuda.core._stream cimport Stream +from cuda.core._utils.cuda_utils import ( + driver, + get_binding_version, + handle_return, +) + +_inited = False +_driver_ver = None + + +def _lazy_init(): + global _inited + if _inited: + return + + global _py_major_minor, _driver_ver + # binding availability depends on cuda-python version + _py_major_minor = get_binding_version() + _driver_ver = handle_return(driver.cuDriverGetVersion()) + _inited = True + + +@dataclass +class GraphDebugPrintOptions: + """Customizable options for :obj:`_graph.GraphBuilder.debug_dot_print()` + + Attributes + ---------- + verbose : bool + Output all debug data as if every debug flag is enabled (Default to False) + runtime_types : bool + Use CUDA Runtime structures for output (Default to False) + kernel_node_params : bool + Adds kernel parameter values to output (Default to False) + memcpy_node_params : bool + Adds memcpy parameter values to output (Default to False) + memset_node_params : bool + Adds 
memset parameter values to output (Default to False) + host_node_params : bool + Adds host parameter values to output (Default to False) + event_node_params : bool + Adds event parameter values to output (Default to False) + ext_semas_signal_node_params : bool + Adds external semaphore signal parameter values to output (Default to False) + ext_semas_wait_node_params : bool + Adds external semaphore wait parameter values to output (Default to False) + kernel_node_attributes : bool + Adds kernel node attributes to output (Default to False) + handles : bool + Adds node handles and every kernel function handle to output (Default to False) + mem_alloc_node_params : bool + Adds memory alloc parameter values to output (Default to False) + mem_free_node_params : bool + Adds memory free parameter values to output (Default to False) + batch_mem_op_node_params : bool + Adds batch mem op parameter values to output (Default to False) + extra_topo_info : bool + Adds edge numbering information (Default to False) + conditional_node_params : bool + Adds conditional node parameter values to output (Default to False) + + """ + + verbose: bool = False + runtime_types: bool = False + kernel_node_params: bool = False + memcpy_node_params: bool = False + memset_node_params: bool = False + host_node_params: bool = False + event_node_params: bool = False + ext_semas_signal_node_params: bool = False + ext_semas_wait_node_params: bool = False + kernel_node_attributes: bool = False + handles: bool = False + mem_alloc_node_params: bool = False + mem_free_node_params: bool = False + batch_mem_op_node_params: bool = False + extra_topo_info: bool = False + conditional_node_params: bool = False + + def _to_flags(self) -> int: + """Convert options to CUDA driver API flags (internal use).""" + flags = 0 + if self.verbose: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE + if self.runtime_types: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES + 
if self.kernel_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS + if self.memcpy_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS + if self.memset_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS + if self.host_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS + if self.event_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS + if self.ext_semas_signal_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS + if self.ext_semas_wait_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS + if self.kernel_node_attributes: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES + if self.handles: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES + if self.mem_alloc_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS + if self.mem_free_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS + if self.batch_mem_op_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS + if self.extra_topo_info: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO + if self.conditional_node_params: + flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS + return flags + + +@dataclass +class GraphCompleteOptions: + """Customizable options for :obj:`_graph.GraphBuilder.complete()` + + Attributes + ---------- + auto_free_on_launch : bool, optional + Automatically free memory allocated in a graph before relaunching. 
(Default to False) + upload_stream : Stream, optional + Stream to use to automatically upload the graph after completion. (Default to None) + device_launch : bool, optional + Configure the graph to be launchable from the device. This flag can only + be used on platforms which support unified addressing. This flag cannot be + used in conjunction with auto_free_on_launch. (Default to False) + use_node_priority : bool, optional + Run the graph using the per-node priority attributes rather than the + priority of the stream it is launched into. (Default to False) + + """ + + auto_free_on_launch: bool = False + upload_stream: Stream | None = None + device_launch: bool = False + use_node_priority: bool = False + + +def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph: + params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS() + if options: + flags = 0 + if options.auto_free_on_launch: + flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH + if options.upload_stream: + flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD + params.hUploadStream = options.upload_stream.handle + if options.device_launch: + flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH + if options.use_node_priority: + flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY + params.flags = flags + + graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params))) + if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR: + raise RuntimeError( + "Instantiation failed for an unexpected reason which is described in the return value of the function." 
+ ) + elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE: + raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.") + elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED: + raise RuntimeError( + "Instantiation for device launch failed because the graph contained an unsupported operation." + ) + elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED: + raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.") + elif ( + _py_major_minor >= (12, 8) + and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED + ): + raise RuntimeError("One or more conditional handles are not associated with conditional builders.") + elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS: + raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}") + return graph + + +class GraphBuilder: + """Represents a graph under construction. + + A graph groups a set of CUDA kernels and other CUDA operations together and executes + them with a specified dependency tree. It speeds up the workflow by combining the + driver activities associated with CUDA kernel launches and CUDA API calls. + + Directly creating a :obj:`~_graph.GraphBuilder` is not supported due + to ambiguity. New graph builders should instead be created through a + :obj:`~_device.Device`, or a :obj:`~_stream.stream` object. 
+
+    """
+
+    class _MembersNeededForFinalize:
+        __slots__ = ("conditional_graph", "graph", "is_join_required", "is_stream_owner", "stream")
+
+        def __init__(self, graph_builder_obj, stream_obj, is_stream_owner, conditional_graph, is_join_required):
+            self.stream = stream_obj
+            self.is_stream_owner = is_stream_owner
+            self.graph = None
+            self.conditional_graph = conditional_graph
+            self.is_join_required = is_join_required
+            weakref.finalize(graph_builder_obj, self.close)
+
+        def close(self):
+            if self.stream:
+                if not self.is_join_required:
+                    capture_status = handle_return(driver.cuStreamGetCaptureInfo(self.stream.handle))[0]
+                    if capture_status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE:
+                        # Note how this condition only occurs for the primary graph builder
+                        # This is because calling cuStreamEndCapture on streams that were split off of the primary
+                        # would error out with CUDA_ERROR_STREAM_CAPTURE_UNJOINED.
+                        # Therefore, it is currently a requirement that users join all split graph builders
+                        # before a graph builder can be cleanly destroyed.
+                        handle_return(driver.cuStreamEndCapture(self.stream.handle))
+                if self.is_stream_owner:
+                    self.stream.close()
+                self.stream = None
+            if self.graph:
+                handle_return(driver.cuGraphDestroy(self.graph))
+                self.graph = None
+            self.conditional_graph = None
+
+    __slots__ = ("__weakref__", "_building_ended", "_mnff")
+
+    def __init__(self):
+        raise NotImplementedError(
+            "directly creating a Graph object can be ambiguous. Please either "
+            "call Device.create_graph_builder() or stream.create_graph_builder()"
+        )
+
+    @classmethod
+    def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False):
+        self = cls.__new__(cls)
+        _lazy_init()
+        self._mnff = GraphBuilder._MembersNeededForFinalize(
+            self, stream, is_stream_owner, conditional_graph, is_join_required
+        )
+
+        self._building_ended = False
+        return self
+
+    @property
+    def stream(self) -> Stream:
+        """Returns the stream associated with the graph builder."""
+        return self._mnff.stream
+
+    @property
+    def is_join_required(self) -> bool:
+        """Returns True if this graph builder must be joined before building is ended."""
+        return self._mnff.is_join_required
+
+    def begin_building(self, mode="relaxed") -> GraphBuilder:
+        """Begins the building process.
+
+        Build `mode` for controlling interaction with other API calls must be one of the following:
+
+        - `global` : Prohibit potentially unsafe operations across all streams in the process.
+        - `thread_local` : Prohibit potentially unsafe operations in streams created by the current thread.
+        - `relaxed` : The local thread is not prohibited from potentially unsafe operations.
+
+        Parameters
+        ----------
+        mode : str, optional
+            Build mode to control the interaction with other API calls that are potentially unsafe.
+            Default set to use relaxed.
+ + """ + if self._building_ended: + raise RuntimeError("Cannot resume building after building has ended.") + if mode not in ("global", "thread_local", "relaxed"): + raise ValueError(f"Unsupported build mode: {mode}") + if mode == "global": + capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_GLOBAL + elif mode == "thread_local": + capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_THREAD_LOCAL + elif mode == "relaxed": + capture_mode = driver.CUstreamCaptureMode.CU_STREAM_CAPTURE_MODE_RELAXED + else: + raise ValueError(f"Unsupported build mode: {mode}") + + if self._mnff.conditional_graph: + handle_return( + driver.cuStreamBeginCaptureToGraph( + self._mnff.stream.handle, + self._mnff.conditional_graph, + None, # dependencies + None, # dependencyData + 0, # numDependencies + capture_mode, + ) + ) + else: + handle_return(driver.cuStreamBeginCapture(self._mnff.stream.handle, capture_mode)) + return self + + @property + def is_building(self) -> bool: + """Returns True if the graph builder is currently building.""" + capture_status = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle))[0] + if capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_NONE: + return False + elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + return True + elif capture_status == driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_INVALIDATED: + raise RuntimeError( + "Build process encountered an error and has been invalidated. Build process must now be ended." 
+            )
+        else:
+            raise NotImplementedError(f"Unsupported capture status type received: {capture_status}")
+
+    def end_building(self) -> GraphBuilder:
+        """Ends the building process."""
+        if not self.is_building:
+            raise RuntimeError("Graph builder is not building.")
+        if self._mnff.conditional_graph:
+            self._mnff.conditional_graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
+        else:
+            self._mnff.graph = handle_return(driver.cuStreamEndCapture(self.stream.handle))
+
+        # TODO: Resolving https://github.com/NVIDIA/cuda-python/issues/617 would allow us to
+        # resume the build process after the first call to end_building()
+        self._building_ended = True
+        return self
+
+    def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
+        """Completes the graph builder and returns the built :obj:`~_graph.Graph` object.
+
+        Parameters
+        ----------
+        options : :obj:`~_graph.GraphCompleteOptions`, optional
+            Customizable dataclass for the graph builder completion options.
+
+        Returns
+        -------
+        graph : :obj:`~_graph.Graph`
+            The newly built graph.
+
+        """
+        if not self._building_ended:
+            raise RuntimeError("Graph has not finished building.")
+
+        return _instantiate_graph(self._mnff.graph, options)
+
+    def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
+        """Generates a DOT debug file for the graph builder.
+
+        Parameters
+        ----------
+        path : str
+            File path to use for writing debug DOT output
+        options : :obj:`~_graph.GraphDebugPrintOptions`, optional
+            Customizable dataclass for the debug print options.
+
+        """
+        if not self._building_ended:
+            raise RuntimeError("Graph has not finished building.")
+        flags = options._to_flags() if options else 0
+        handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))
+
+    def split(self, count: int) -> tuple[GraphBuilder, ...]:
+        """Splits the original graph builder into multiple graph builders.
+
+        The new builders inherit work dependencies from the original builder.
+ The original builder is reused for the split and is returned first in the tuple. + + Parameters + ---------- + count : int + The number of graph builders to split the graph builder into. + + Returns + ------- + graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...] + A tuple of split graph builders. The first graph builder in the tuple + is always the original graph builder. + + """ + if count < 2: + raise ValueError(f"Invalid split count: expecting >= 2, got {count}") + + event = self._mnff.stream.record() + result = [self] + for i in range(count - 1): + stream = self._mnff.stream.device.create_stream() + stream.wait(event) + result.append( + GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) + ) + event.close() + return result + + @staticmethod + def join(*graph_builders) -> GraphBuilder: + """Joins multiple graph builders into a single graph builder. + + The returned builder inherits work dependencies from the provided builders. + + Parameters + ---------- + *graph_builders : :obj:`~_graph.GraphBuilder` + The graph builders to join. + + Returns + ------- + graph_builder : :obj:`~_graph.GraphBuilder` + The newly joined graph builder. 
+ + """ + if any(not isinstance(builder, GraphBuilder) for builder in graph_builders): + raise TypeError("All arguments must be GraphBuilder instances") + if len(graph_builders) < 2: + raise ValueError("Must join with at least two graph builders") + + # Discover the root builder others should join + root_idx = 0 + for i, builder in enumerate(graph_builders): + if not builder.is_join_required: + root_idx = i + break + + # Join all onto the root builder + root_bdr = graph_builders[root_idx] + for idx, builder in enumerate(graph_builders): + if idx == root_idx: + continue + root_bdr.stream.wait(builder.stream) + builder.close() + + return root_bdr + + def __cuda_stream__(self) -> tuple[int, int]: + """Return an instance of a __cuda_stream__ protocol.""" + return self.stream.__cuda_stream__() + + def _get_conditional_context(self) -> driver.CUcontext: + return self._mnff.stream.context.handle + + def create_conditional_handle(self, default_value=None) -> driver.CUgraphConditionalHandle: + """Creates a conditional handle for the graph builder. + + Parameters + ---------- + default_value : int, optional + The default value to assign to the conditional handle. + + Returns + ------- + handle : driver.CUgraphConditionalHandle + The newly created conditional handle. 
+ + """ + if _driver_ver < 12030: + raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles") + if _py_major_minor < (12, 3): + raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles") + if default_value is not None: + flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT + else: + default_value = 0 + flags = 0 + + status, _, graph, *_, _ = handle_return(driver.cuStreamGetCaptureInfo(self._mnff.stream.handle)) + if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + raise RuntimeError("Cannot create a conditional handle when graph is not being built") + + return handle_return( + driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags) + ) + + def _cond_with_params(self, node_params) -> GraphBuilder: + # Get current capture info to ensure we're in a valid state + status, _, graph, *deps_info, num_dependencies = handle_return( + driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) + ) + if status != driver.CUstreamCaptureStatus.CU_STREAM_CAPTURE_STATUS_ACTIVE: + raise RuntimeError("Cannot add conditional node when not actively capturing") + + # Add the conditional node to the graph + deps_info_update = [ + [handle_return(driver.cuGraphAddNode(graph, *deps_info, num_dependencies, node_params))] + ] + [None] * (len(deps_info) - 1) + + # Update the stream's capture dependencies + handle_return( + driver.cuStreamUpdateCaptureDependencies( + self._mnff.stream.handle, + *deps_info_update, # dependencies, edgeData + 1, # numDependencies + driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, + ) + ) + + # Create new graph builders for each condition + return tuple( + [ + GraphBuilder._init( + stream=self._mnff.stream.device.create_stream(), + is_stream_owner=True, + conditional_graph=node_params.conditional.phGraph_out[i], + is_join_required=False, + ) + for i in range(node_params.conditional.size) + ] + ) + + def 
if_cond(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder: + """Adds an if condition branch and returns a new graph builder for it. + + The resulting if graph will only execute the branch if the conditional + handle evaluates to true at runtime. + + The new builder inherits work dependencies from the original builder. + + Parameters + ---------- + handle : driver.CUgraphConditionalHandle + The handle to use for the if conditional. + + Returns + ------- + graph_builder : :obj:`~_graph.GraphBuilder` + The newly created conditional graph builder. + + """ + if _driver_ver < 12030: + raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if") + if _py_major_minor < (12, 3): + raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if") + node_params = driver.CUgraphNodeParams() + node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL + node_params.conditional.handle = handle + node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF + node_params.conditional.size = 1 + node_params.conditional.ctx = self._get_conditional_context() + return self._cond_with_params(node_params)[0] + + def if_else(self, handle: driver.CUgraphConditionalHandle) -> tuple[GraphBuilder, GraphBuilder]: + """Adds an if-else condition branch and returns new graph builders for both branches. + + The resulting if graph will execute the branch if the conditional handle + evaluates to true at runtime, otherwise the else branch will execute. + + The new builders inherit work dependencies from the original builder. + + Parameters + ---------- + handle : driver.CUgraphConditionalHandle + The handle to use for the if-else conditional. + + Returns + ------- + graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`] + A tuple of two new graph builders, one for the if branch and one for the else branch. 
+ + """ + if _driver_ver < 12080: + raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else") + if _py_major_minor < (12, 8): + raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else") + node_params = driver.CUgraphNodeParams() + node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL + node_params.conditional.handle = handle + node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_IF + node_params.conditional.size = 2 + node_params.conditional.ctx = self._get_conditional_context() + return self._cond_with_params(node_params) + + def switch(self, handle: driver.CUgraphConditionalHandle, count: int) -> tuple[GraphBuilder, ...]: + """Adds a switch condition branch and returns new graph builders for all cases. + + The resulting switch graph will execute the branch that matches the + case index of the conditional handle at runtime. If no match is found, no branch + will be executed. + + The new builders inherit work dependencies from the original builder. + + Parameters + ---------- + handle : driver.CUgraphConditionalHandle + The handle to use for the switch conditional. + count : int + The number of cases to add to the switch conditional. + + Returns + ------- + graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...] + A tuple of new graph builders, one for each branch. 
+ + """ + if _driver_ver < 12080: + raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch") + if _py_major_minor < (12, 8): + raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch") + node_params = driver.CUgraphNodeParams() + node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL + node_params.conditional.handle = handle + node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_SWITCH + node_params.conditional.size = count + node_params.conditional.ctx = self._get_conditional_context() + return self._cond_with_params(node_params) + + def while_loop(self, handle: driver.CUgraphConditionalHandle) -> GraphBuilder: + """Adds a while loop and returns a new graph builder for it. + + The resulting while loop graph will execute the branch repeatedly at runtime + until the conditional handle evaluates to false. + + The new builder inherits work dependencies from the original builder. + + Parameters + ---------- + handle : driver.CUgraphConditionalHandle + The handle to use for the while loop. + + Returns + ------- + graph_builder : :obj:`~_graph.GraphBuilder` + The newly created while loop graph builder. + + """ + if _driver_ver < 12030: + raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop") + if _py_major_minor < (12, 3): + raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop") + node_params = driver.CUgraphNodeParams() + node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL + node_params.conditional.handle = handle + node_params.conditional.type = driver.CUgraphConditionalNodeType.CU_GRAPH_COND_TYPE_WHILE + node_params.conditional.size = 1 + node_params.conditional.ctx = self._get_conditional_context() + return self._cond_with_params(node_params)[0] + + def close(self): + """Destroy the graph builder. + + Closes the associated stream if we own it. 
Borrowed stream + object will instead have their references released. + + """ + self._mnff.close() + + def add_child(self, child_graph: GraphBuilder): + """Adds the child :obj:`~_graph.GraphBuilder` builder into self. + + The child graph builder will be added as a child node to the parent graph builder. + + Parameters + ---------- + child_graph : :obj:`~_graph.GraphBuilder` + The child graph builder. Must have finished building. + """ + if (_driver_ver < 12000) or (_py_major_minor < (12, 0)): + raise NotImplementedError( + f"Launching child graphs is not implemented for versions older than CUDA 12." + f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}" + ) + + if not child_graph._building_ended: + raise ValueError("Child graph has not finished building.") + + if not self.is_building: + raise ValueError("Parent graph is not being built.") + + stream_handle = self._mnff.stream.handle + _, _, graph_out, *deps_info_out, num_dependencies_out = handle_return( + driver.cuStreamGetCaptureInfo(stream_handle) + ) + + # See https://github.com/NVIDIA/cuda-python/pull/879#issuecomment-3211054159 + # for rationale + deps_info_trimmed = deps_info_out[:num_dependencies_out] + deps_info_update = [ + [ + handle_return( + driver.cuGraphAddChildGraphNode( + graph_out, *deps_info_trimmed, num_dependencies_out, child_graph._mnff.graph + ) + ) + ] + ] + [None] * (len(deps_info_out) - 1) + handle_return( + driver.cuStreamUpdateCaptureDependencies( + stream_handle, + *deps_info_update, # dependencies, edgeData + 1, + driver.CUstreamUpdateCaptureDependencies_flags.CU_STREAM_SET_CAPTURE_DEPENDENCIES, + ) + ) + + +class Graph: + """Represents an executable graph. + + A graph groups a set of CUDA kernels and other CUDA operations together and executes + them with a specified dependency tree. It speeds up the workflow by combining the + driver activities associated with CUDA kernel launches and CUDA API calls. 
+ + Graphs must be built using a :obj:`~_graph.GraphBuilder` object. + + """ + + class _MembersNeededForFinalize: + __slots__ = "graph" + + def __init__(self, graph_obj, graph): + self.graph = graph + weakref.finalize(graph_obj, self.close) + + def close(self): + if self.graph: + handle_return(driver.cuGraphExecDestroy(self.graph)) + self.graph = None + + __slots__ = ("__weakref__", "_mnff") + + def __init__(self): + raise RuntimeError("directly constructing a Graph instance is not supported") + + @classmethod + def _init(cls, graph): + self = cls.__new__(cls) + self._mnff = Graph._MembersNeededForFinalize(self, graph) + return self + + def close(self): + """Destroy the graph.""" + self._mnff.close() + + @property + def handle(self) -> driver.CUgraphExec: + """Return the underlying ``CUgraphExec`` object. + + .. caution:: + + This handle is a Python object. To get the memory address of the underlying C + handle, call ``int()`` on the returned object. + + """ + return self._mnff.graph + + def update(self, builder: GraphBuilder): + """Update the graph using new build configuration from the builder. + + The topology of the provided builder must be identical to this graph. + + Parameters + ---------- + builder : :obj:`~_graph.GraphBuilder` + The builder to update the graph with. + + """ + if not builder._building_ended: + raise ValueError("Graph has not finished building.") + + # Update the graph with the new nodes from the builder + exec_update_result = handle_return(driver.cuGraphExecUpdate(self._mnff.graph, builder._mnff.graph)) + if exec_update_result.result != driver.CUgraphExecUpdateResult.CU_GRAPH_EXEC_UPDATE_SUCCESS: + raise RuntimeError(f"Failed to update graph: {exec_update_result.result()}") + + def upload(self, stream: Stream): + """Uploads the graph in a stream. 
+ + Parameters + ---------- + stream : :obj:`~_stream.Stream` + The stream in which to upload the graph + + """ + handle_return(driver.cuGraphUpload(self._mnff.graph, stream.handle)) + + def launch(self, stream: Stream): + """Launches the graph in a stream. + + Parameters + ---------- + stream : :obj:`~_stream.Stream` + The stream in which to launch the graph + + """ + handle_return(driver.cuGraphLaunch(self._mnff.graph, stream.handle)) diff --git a/cuda_core/cuda/core/_graph/_graphdef.pyx b/cuda_core/cuda/core/_graph/_graphdef.pyx index 107d0ecc8b..cad2523a4a 100644 --- a/cuda_core/cuda/core/_graph/_graphdef.pyx +++ b/cuda_core/cuda/core/_graph/_graphdef.pyx @@ -470,7 +470,7 @@ cdef class GraphDef: Graph An executable graph that can be launched on a stream. """ - from cuda.core._graph import _instantiate_graph + from cuda.core._graph._graph_builder import _instantiate_graph return _instantiate_graph( driver.CUgraph(as_intptr(self._h_graph)), options) @@ -485,7 +485,7 @@ cdef class GraphDef: options : GraphDebugPrintOptions, optional Customizable options for the debug print. 
""" - from cuda.core._graph import GraphDebugPrintOptions + from cuda.core._graph._graph_builder import GraphDebugPrintOptions cdef unsigned int flags = 0 if options is not None: From 9ac34dfc1e8be72a6f2ab0343d4c63eeef238675 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 24 Mar 2026 11:53:26 -0700 Subject: [PATCH 2/8] Remove _lazy_init from _graph_builder; add cached get_driver_version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the per-module _lazy_init / _inited / _driver_ver / _py_major_minor pattern in _graph_builder.pyx with direct calls to centralized cached functions in cuda_utils: - Add get_driver_version() with @functools.cache alongside get_binding_version - Switch get_binding_version from @functools.lru_cache to @functools.cache (cleaner for nullary functions) - Fix split() to return tuple(result) — Cython enforces return type annotations unlike pure Python - Fix _cond_with_params annotation from -> GraphBuilder to -> tuple to match actual return value Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 68 +++++++------------ cuda_core/cuda/core/_utils/cuda_utils.pyx | 6 +- 2 files changed, 31 insertions(+), 43 deletions(-) diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index a0141be23c..f0f1298fbb 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -9,25 +9,10 @@ from cuda.core._stream cimport Stream from cuda.core._utils.cuda_utils import ( driver, get_binding_version, + get_driver_version, handle_return, ) -_inited = False -_driver_ver = None - - -def _lazy_init(): - global _inited - if _inited: - return - - global _py_major_minor, _driver_ver - # binding availability depends on cuda-python version - _py_major_minor = get_binding_version() - _driver_ver = handle_return(driver.cuDriverGetVersion()) - _inited = True - - @dataclass class 
GraphDebugPrintOptions: """Customizable options for :obj:`_graph.GraphBuilder.debug_dot_print()` @@ -179,7 +164,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED: raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.") elif ( - _py_major_minor >= (12, 8) + get_binding_version() >= (12, 8) and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED ): raise RuntimeError("One or more conditional handles are not associated with conditional builders.") @@ -242,7 +227,6 @@ class GraphBuilder: @classmethod def _init(cls, stream, is_stream_owner, conditional_graph=None, is_join_required=False): self = cls.__new__(cls) - _lazy_init() self._mnff = GraphBuilder._MembersNeededForFinalize( self, stream, is_stream_owner, conditional_graph, is_join_required ) @@ -398,7 +382,7 @@ class GraphBuilder: GraphBuilder._init(stream=stream, is_stream_owner=True, conditional_graph=None, is_join_required=True) ) event.close() - return result + return tuple(result) @staticmethod def join(*graph_builders) -> GraphBuilder: @@ -460,10 +444,10 @@ class GraphBuilder: The newly created conditional handle. 
""" - if _driver_ver < 12030: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional handles") - if _py_major_minor < (12, 3): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional handles") + if get_driver_version() < 12030: + raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional handles") + if get_binding_version() < (12, 3): + raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional handles") if default_value is not None: flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT else: @@ -478,7 +462,7 @@ class GraphBuilder: driver.cuGraphConditionalHandleCreate(graph, self._get_conditional_context(), default_value, flags) ) - def _cond_with_params(self, node_params) -> GraphBuilder: + def _cond_with_params(self, node_params) -> tuple: # Get current capture info to ensure we're in a valid state status, _, graph, *deps_info, num_dependencies = handle_return( driver.cuStreamGetCaptureInfo(self._mnff.stream.handle) @@ -533,10 +517,10 @@ class GraphBuilder: The newly created conditional graph builder. """ - if _driver_ver < 12030: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if") - if _py_major_minor < (12, 3): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if") + if get_driver_version() < 12030: + raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if") + if get_binding_version() < (12, 3): + raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -564,10 +548,10 @@ class GraphBuilder: A tuple of two new graph builders, one for the if branch and one for the else branch. 
""" - if _driver_ver < 12080: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional if-else") - if _py_major_minor < (12, 8): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional if-else") + if get_driver_version() < 12080: + raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if-else") + if get_binding_version() < (12, 8): + raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if-else") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -598,10 +582,10 @@ class GraphBuilder: A tuple of new graph builders, one for each branch. """ - if _driver_ver < 12080: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional switch") - if _py_major_minor < (12, 8): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional switch") + if get_driver_version() < 12080: + raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional switch") + if get_binding_version() < (12, 8): + raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional switch") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -629,10 +613,10 @@ class GraphBuilder: The newly created while loop graph builder. 
""" - if _driver_ver < 12030: - raise RuntimeError(f"Driver version {_driver_ver} does not support conditional while loop") - if _py_major_minor < (12, 3): - raise RuntimeError(f"Binding version {_py_major_minor} does not support conditional while loop") + if get_driver_version() < 12030: + raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional while loop") + if get_binding_version() < (12, 3): + raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional while loop") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -660,10 +644,10 @@ class GraphBuilder: child_graph : :obj:`~_graph.GraphBuilder` The child graph builder. Must have finished building. """ - if (_driver_ver < 12000) or (_py_major_minor < (12, 0)): + if (get_driver_version() < 12000) or (get_binding_version() < (12, 0)): raise NotImplementedError( f"Launching child graphs is not implemented for versions older than CUDA 12." 
- f"Found driver version is {_driver_ver} and binding version is {_py_major_minor}" + f"Found driver version is {get_driver_version()} and binding version is {get_binding_version()}" ) if not child_graph._building_ended: diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index 999b3be325..ec6c587f3f 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -298,7 +298,7 @@ def is_nested_sequence(obj): return is_sequence(obj) and any(is_sequence(elem) for elem in obj) -@functools.lru_cache +@functools.cache def get_binding_version(): try: major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2] @@ -306,6 +306,10 @@ def get_binding_version(): major_minor = importlib.metadata.version("cuda-python").split(".")[:2] return tuple(int(v) for v in major_minor) +@functools.cache +def get_driver_version(): + return handle_return(driver.cuDriverGetVersion()) + class Transaction: """ From edbc361d2d465781aea0d0b391dbe4f6978dfc3e Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Tue, 24 Mar 2026 13:36:34 -0700 Subject: [PATCH 3/8] Add CPU callbacks for stream capture (GraphBuilder.callback) Implements #1328: host callbacks during stream capture via cuLaunchHostFunc, mirroring the existing GraphDef.callback API. Extracts shared callback infrastructure (_attach_user_object, _attach_host_callback_to_graph, trampoline/destructor) into a new _graph/_utils.pyx module to avoid circular imports between _graph_builder and _graphdef. 
Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 56 +++++++++ cuda_core/cuda/core/_graph/_graphdef.pyx | 87 ++------------ cuda_core/cuda/core/_graph/_utils.pxd | 16 +++ cuda_core/cuda/core/_graph/_utils.pyx | 106 ++++++++++++++++++ cuda_core/tests/graph/test_basic.py | 42 +++++++ 5 files changed, 230 insertions(+), 77 deletions(-) create mode 100644 cuda_core/cuda/core/_graph/_utils.pxd create mode 100644 cuda_core/cuda/core/_graph/_utils.pyx diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index f0f1298fbb..f1a11b5ded 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -5,7 +5,12 @@ import weakref from dataclasses import dataclass +from cuda.bindings cimport cydriver + +from cuda.core._graph._utils cimport _attach_host_callback_to_graph +from cuda.core._resource_handles cimport as_cu from cuda.core._stream cimport Stream +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN from cuda.core._utils.cuda_utils import ( driver, get_binding_version, @@ -682,6 +687,57 @@ class GraphBuilder: ) ) + def callback(self, fn, *, user_data=None): + """Add a host callback to the graph during stream capture. + + The callback runs on the host CPU when the graph reaches this point + in execution. Two modes are supported: + + - **Python callable**: Pass any callable. The GIL is acquired + automatically. The callable must take no arguments; use closures + or ``functools.partial`` to bind state. + - **ctypes function pointer**: Pass a ``ctypes.CFUNCTYPE`` instance. + The function receives a single ``void*`` argument (the + ``user_data``). The caller must keep the ctypes wrapper alive + for the lifetime of the graph. + + .. warning:: + + Callbacks must not call CUDA API functions. Doing so may + deadlock or corrupt driver state. + + Parameters + ---------- + fn : callable or ctypes function pointer + The callback function. 
+ user_data : int or bytes-like, optional + Only for ctypes function pointers. If ``int``, passed as a raw + pointer (caller manages lifetime). If bytes-like, the data is + copied and its lifetime is tied to the graph. + """ + cdef Stream stream = self._mnff.stream + cdef cydriver.CUstream c_stream = as_cu(stream._h_stream) + cdef cydriver.CUstreamCaptureStatus capture_status + cdef cydriver.CUgraph c_graph = NULL + + with nogil: + IF CUDA_CORE_BUILD_MAJOR >= 13: + HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + c_stream, &capture_status, NULL, &c_graph, NULL, NULL, NULL)) + ELSE: + HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo( + c_stream, &capture_status, NULL, &c_graph, NULL, NULL)) + + if capture_status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE: + raise RuntimeError("Cannot add callback when graph is not being built") + + cdef cydriver.CUhostFn c_fn + cdef void* c_user_data = NULL + _attach_host_callback_to_graph(c_graph, fn, user_data, &c_fn, &c_user_data) + + with nogil: + HANDLE_RETURN(cydriver.cuLaunchHostFunc(c_stream, c_fn, c_user_data)) + class Graph: """Represents an executable graph. 
diff --git a/cuda_core/cuda/core/_graph/_graphdef.pyx b/cuda_core/cuda/core/_graph/_graphdef.pyx index cad2523a4a..dd4ee22ae1 100644 --- a/cuda_core/cuda/core/_graph/_graphdef.pyx +++ b/cuda_core/cuda/core/_graph/_graphdef.pyx @@ -30,8 +30,6 @@ GraphNode hierarchy: from __future__ import annotations -from cpython.ref cimport Py_INCREF - from libc.stddef cimport size_t from libc.stdint cimport uintptr_t from libc.stdlib cimport malloc, free @@ -102,16 +100,11 @@ cdef bint _check_node_get_params(): return _has_cuGraphNodeGetParams -cdef extern from "Python.h": - void _py_decref "Py_DECREF" (void*) - - -cdef void _py_host_trampoline(void* data) noexcept with gil: - (data)() - - -cdef void _py_host_destructor(void* data) noexcept with gil: - _py_decref(data) +from cuda.core._graph._utils cimport ( + _attach_host_callback_to_graph, + _attach_user_object, + _is_py_host_trampoline, +) cdef void _destroy_event_handle_copy(void* ptr) noexcept nogil: @@ -124,30 +117,6 @@ cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil: del p -cdef void _attach_user_object( - cydriver.CUgraph graph, void* ptr, - cydriver.CUhostFn destroy) except *: - """Create a CUDA user object and transfer ownership to the graph. - - On success the graph owns the resource (via MOVE semantics). - On failure the destroy callback is invoked to clean up ptr, - then a CUDAError is raised — callers need no try/except. 
- """ - cdef cydriver.CUuserObject user_obj = NULL - cdef cydriver.CUresult ret - with nogil: - ret = cydriver.cuUserObjectCreate( - &user_obj, ptr, destroy, 1, - cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC) - if ret == cydriver.CUDA_SUCCESS: - ret = cydriver.cuGraphRetainUserObject( - graph, user_obj, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE) - if ret != cydriver.CUDA_SUCCESS: - cydriver.cuUserObjectRelease(user_obj, 1) - if ret != cydriver.CUDA_SUCCESS: - if user_obj == NULL: - destroy(ptr) - HANDLE_RETURN(ret) cdef class Condition: @@ -1270,56 +1239,20 @@ cdef class GraphNode: cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node) cdef cydriver.CUgraphNode* deps = NULL cdef size_t num_deps = 0 - cdef void* c_user_data = NULL - cdef object callable_obj = None - cdef void* fn_pyobj = NULL if pred_node != NULL: deps = &pred_node num_deps = 1 - if isinstance(fn, ct._CFuncPtr): - Py_INCREF(fn) - fn_pyobj = fn - _attach_user_object( - as_cu(h_graph), fn_pyobj, - _py_host_destructor) - node_params.fn = ct.cast( - fn, ct.c_void_p).value - - if user_data is not None: - if isinstance(user_data, int): - c_user_data = user_data - else: - buf = bytes(user_data) - c_user_data = malloc(len(buf)) - if c_user_data == NULL: - raise MemoryError( - "failed to allocate user_data buffer") - c_memcpy(c_user_data, buf, len(buf)) - _attach_user_object( - as_cu(h_graph), c_user_data, - free) - - node_params.userData = c_user_data - else: - if user_data is not None: - raise ValueError( - "user_data is only supported with ctypes " - "function pointers") - callable_obj = fn - Py_INCREF(fn) - fn_pyobj = fn - node_params.fn = _py_host_trampoline - node_params.userData = fn_pyobj - _attach_user_object( - as_cu(h_graph), fn_pyobj, - _py_host_destructor) + _attach_host_callback_to_graph( + as_cu(h_graph), fn, user_data, + &node_params.fn, &node_params.userData) with nogil: HANDLE_RETURN(cydriver.cuGraphAddHostNode( &new_node, as_cu(h_graph), deps, num_deps, &node_params)) + cdef object 
callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None self._succ_cache = None return HostCallbackNode._create_with_params( create_graph_node_handle(new_node, h_graph), callable_obj, @@ -1947,7 +1880,7 @@ cdef class HostCallbackNode(GraphNode): HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, ¶ms)) cdef object callable_obj = None - if params.fn == _py_host_trampoline: + if _is_py_host_trampoline(params.fn): callable_obj = params.userData return HostCallbackNode._create_with_params( diff --git a/cuda_core/cuda/core/_graph/_utils.pxd b/cuda_core/cuda/core/_graph/_utils.pxd new file mode 100644 index 0000000000..63fdb00ac4 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_utils.pxd @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from cuda.bindings cimport cydriver + + +cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil + +cdef void _attach_user_object( + cydriver.CUgraph graph, void* ptr, + cydriver.CUhostFn destroy) except * + +cdef void _attach_host_callback_to_graph( + cydriver.CUgraph graph, object fn, object user_data, + cydriver.CUhostFn* out_fn, void** out_user_data) except * diff --git a/cuda_core/cuda/core/_graph/_utils.pyx b/cuda_core/cuda/core/_graph/_utils.pyx new file mode 100644 index 0000000000..bef879ab95 --- /dev/null +++ b/cuda_core/cuda/core/_graph/_utils.pyx @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +from cpython.ref cimport Py_INCREF + +from libc.stdint cimport uintptr_t +from libc.stdlib cimport malloc, free +from libc.string cimport memcpy as c_memcpy + +from cuda.bindings cimport cydriver + +from cuda.core._utils.cuda_utils cimport HANDLE_RETURN + + +cdef extern from "Python.h": + void _py_decref "Py_DECREF" (void*) + + +cdef void _py_host_trampoline(void* data) noexcept with gil: + (data)() + + +cdef void _py_host_destructor(void* data) noexcept with gil: + _py_decref(data) + + +cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil: + return fn == _py_host_trampoline + + +cdef void _attach_user_object( + cydriver.CUgraph graph, void* ptr, + cydriver.CUhostFn destroy) except *: + """Create a CUDA user object and transfer ownership to the graph. + + On success the graph owns the resource (via MOVE semantics). + On failure the destroy callback is invoked to clean up ptr, + then a CUDAError is raised — callers need no try/except. + """ + cdef cydriver.CUuserObject user_obj = NULL + cdef cydriver.CUresult ret + with nogil: + ret = cydriver.cuUserObjectCreate( + &user_obj, ptr, destroy, 1, + cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC) + if ret == cydriver.CUDA_SUCCESS: + ret = cydriver.cuGraphRetainUserObject( + graph, user_obj, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE) + if ret != cydriver.CUDA_SUCCESS: + cydriver.cuUserObjectRelease(user_obj, 1) + if ret != cydriver.CUDA_SUCCESS: + if user_obj == NULL: + destroy(ptr) + HANDLE_RETURN(ret) + + +cdef void _attach_host_callback_to_graph( + cydriver.CUgraph graph, object fn, object user_data, + cydriver.CUhostFn* out_fn, void** out_user_data) except *: + """Resolve a Python callable or ctypes CFuncPtr into a C callback pair. + + Handles Py_INCREF, user-object attachment for lifetime management, + and user_data copying. On return, *out_fn and *out_user_data are + ready to pass to cuGraphAddHostNode or cuLaunchHostFunc. 
+ """ + import ctypes as ct + + cdef void* fn_pyobj = NULL + + if isinstance(fn, ct._CFuncPtr): + Py_INCREF(fn) + fn_pyobj = fn + _attach_user_object( + graph, fn_pyobj, + _py_host_destructor) + out_fn[0] = ct.cast( + fn, ct.c_void_p).value + + if user_data is not None: + if isinstance(user_data, int): + out_user_data[0] = user_data + else: + buf = bytes(user_data) + out_user_data[0] = malloc(len(buf)) + if out_user_data[0] == NULL: + raise MemoryError( + "failed to allocate user_data buffer") + c_memcpy(out_user_data[0], buf, len(buf)) + _attach_user_object( + graph, out_user_data[0], + free) + else: + out_user_data[0] = NULL + else: + if user_data is not None: + raise ValueError( + "user_data is only supported with ctypes " + "function pointers") + Py_INCREF(fn) + fn_pyobj = fn + out_fn[0] = _py_host_trampoline + out_user_data[0] = fn_pyobj + _attach_user_object( + graph, fn_pyobj, + _py_host_destructor) diff --git a/cuda_core/tests/graph/test_basic.py b/cuda_core/tests/graph/test_basic.py index 9de1880b5c..af1c744dbf 100644 --- a/cuda_core/tests/graph/test_basic.py +++ b/cuda_core/tests/graph/test_basic.py @@ -163,3 +163,45 @@ def test_graph_capture_errors(init_cuda): with pytest.raises(RuntimeError, match="^Graph has not finished building."): gb.complete() gb.end_building().complete() + + +def test_graph_capture_callback_python(init_cuda): + results = [] + + def my_callback(): + results.append(42) + + launch_stream = Device().create_stream() + gb = launch_stream.create_graph_builder().begin_building() + + with pytest.raises(ValueError, match="user_data is only supported"): + gb.callback(my_callback, user_data=b"hello") + + gb.callback(my_callback) + graph = gb.end_building().complete() + + graph.launch(launch_stream) + launch_stream.sync() + + assert results == [42] + + +def test_graph_capture_callback_ctypes(init_cuda): + import ctypes + + CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + result = [0] + + @CALLBACK + def read_byte(data): + result[0] = 
ctypes.cast(data, ctypes.POINTER(ctypes.c_uint8))[0] + + launch_stream = Device().create_stream() + gb = launch_stream.create_graph_builder().begin_building() + gb.callback(read_byte, user_data=bytes([0xAB])) + graph = gb.end_building().complete() + + graph.launch(launch_stream) + launch_stream.sync() + + assert result[0] == 0xAB From 59de5e23afadedd76c549e39705558a654da547d Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 25 Mar 2026 12:59:25 -0700 Subject: [PATCH 4/8] Standardize internal version checks in cuda.core MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move binding and driver version queries into a dedicated cuda/core/_utils/version.{pyx,pxd} module, providing both Python (binding_version, driver_version) and Cython (cy_binding_version, cy_driver_version) entry points. All functions return version tuples ((major, minor, patch)) and are cached—Python via @functools.cache, Cython via module-level globals. Remove get_binding_version / get_driver_version from cuda_utils.pyx and update all internal call sites and tests to import from the new module. Remove version checks for CUDA < 12.0 (now the minimum) and eliminate dead code exposed by the migration: _lazy_init / _use_ex / _kernel_ctypes / _is_cukernel_get_library_supported machinery in _module.pyx, _launcher.pyx, and _launch_config.pyx. The public NVML-based system.get_driver_version API is unrelated and left unchanged. 
Made-with: Cursor --- cuda_core/cuda/core/_graph/_graph_builder.pyx | 52 ++++---- cuda_core/cuda/core/_graph/_graphdef.pyx | 4 +- cuda_core/cuda/core/_launch_config.pyx | 38 ------ cuda_core/cuda/core/_launcher.pyx | 62 ++-------- cuda_core/cuda/core/_linker.pyx | 4 +- .../core/_memory/_virtual_memory_resource.py | 6 +- cuda_core/cuda/core/_module.pyx | 111 +----------------- cuda_core/cuda/core/_program.pyx | 12 +- cuda_core/cuda/core/_utils/cuda_utils.pyx | 12 -- cuda_core/cuda/core/_utils/version.pxd | 6 + cuda_core/cuda/core/_utils/version.pyx | 43 +++++++ cuda_core/tests/graph/test_explicit.py | 8 +- cuda_core/tests/test_cuda_utils.py | 8 +- cuda_core/tests/test_device.py | 21 ++-- cuda_core/tests/test_module.py | 9 +- .../tests/test_optional_dependency_imports.py | 21 +--- cuda_core/tests/test_program.py | 5 +- 17 files changed, 122 insertions(+), 300 deletions(-) create mode 100644 cuda_core/cuda/core/_utils/version.pxd create mode 100644 cuda_core/cuda/core/_utils/version.pyx diff --git a/cuda_core/cuda/core/_graph/_graph_builder.pyx b/cuda_core/cuda/core/_graph/_graph_builder.pyx index f1a11b5ded..58b1d93b9a 100644 --- a/cuda_core/cuda/core/_graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/_graph/_graph_builder.pyx @@ -11,10 +11,10 @@ from cuda.core._graph._utils cimport _attach_host_callback_to_graph from cuda.core._resource_handles cimport as_cu from cuda.core._stream cimport Stream from cuda.core._utils.cuda_utils cimport HANDLE_RETURN +from cuda.core._utils.version cimport cy_binding_version, cy_driver_version + from cuda.core._utils.cuda_utils import ( driver, - get_binding_version, - get_driver_version, handle_return, ) @@ -169,7 +169,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED: raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.") elif ( - 
get_binding_version() >= (12, 8) + cy_binding_version() >= (12, 8, 0) and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED ): raise RuntimeError("One or more conditional handles are not associated with conditional builders.") @@ -449,10 +449,10 @@ class GraphBuilder: The newly created conditional handle. """ - if get_driver_version() < 12030: - raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional handles") - if get_binding_version() < (12, 3): - raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional handles") + if cy_driver_version() < (12, 3, 0): + raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles") + if cy_binding_version() < (12, 3, 0): + raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional handles") if default_value is not None: flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT else: @@ -522,10 +522,10 @@ class GraphBuilder: The newly created conditional graph builder. """ - if get_driver_version() < 12030: - raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if") - if get_binding_version() < (12, 3): - raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if") + if cy_driver_version() < (12, 3, 0): + raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if") + if cy_binding_version() < (12, 3, 0): + raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -553,10 +553,10 @@ class GraphBuilder: A tuple of two new graph builders, one for the if branch and one for the else branch. 
""" - if get_driver_version() < 12080: - raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if-else") - if get_binding_version() < (12, 8): - raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if-else") + if cy_driver_version() < (12, 8, 0): + raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if-else") + if cy_binding_version() < (12, 8, 0): + raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if-else") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -587,10 +587,10 @@ class GraphBuilder: A tuple of new graph builders, one for each branch. """ - if get_driver_version() < 12080: - raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional switch") - if get_binding_version() < (12, 8): - raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional switch") + if cy_driver_version() < (12, 8, 0): + raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional switch") + if cy_binding_version() < (12, 8, 0): + raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional switch") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -618,10 +618,10 @@ class GraphBuilder: The newly created while loop graph builder. 
""" - if get_driver_version() < 12030: - raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional while loop") - if get_binding_version() < (12, 3): - raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional while loop") + if cy_driver_version() < (12, 3, 0): + raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional while loop") + if cy_binding_version() < (12, 3, 0): + raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional while loop") node_params = driver.CUgraphNodeParams() node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL node_params.conditional.handle = handle @@ -649,12 +649,6 @@ class GraphBuilder: child_graph : :obj:`~_graph.GraphBuilder` The child graph builder. Must have finished building. """ - if (get_driver_version() < 12000) or (get_binding_version() < (12, 0)): - raise NotImplementedError( - f"Launching child graphs is not implemented for versions older than CUDA 12." 
- f"Found driver version is {get_driver_version()} and binding version is {get_binding_version()}" - ) - if not child_graph._building_ended: raise ValueError("Child graph has not finished building.") diff --git a/cuda_core/cuda/core/_graph/_graphdef.pyx b/cuda_core/cuda/core/_graph/_graphdef.pyx index dd4ee22ae1..e924540281 100644 --- a/cuda_core/cuda/core/_graph/_graphdef.pyx +++ b/cuda_core/cuda/core/_graph/_graphdef.pyx @@ -94,8 +94,8 @@ cdef bint _version_checked = False cdef bint _check_node_get_params(): global _has_cuGraphNodeGetParams, _version_checked if not _version_checked: - ver = handle_return(driver.cuDriverGetVersion()) - _has_cuGraphNodeGetParams = ver >= 13020 + from cuda.core._utils.version import driver_version + _has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0) _version_checked = True return _has_cuGraphNodeGetParams diff --git a/cuda_core/cuda/core/_launch_config.pyx b/cuda_core/cuda/core/_launch_config.pyx index 798df71d9e..285110a52b 100644 --- a/cuda_core/cuda/core/_launch_config.pyx +++ b/cuda_core/cuda/core/_launch_config.pyx @@ -8,45 +8,16 @@ from cuda.core._utils.cuda_utils cimport ( HANDLE_RETURN, ) -import threading - from cuda.core._device import Device from cuda.core._utils.cuda_utils import ( CUDAError, cast_to_3_tuple, driver, - get_binding_version, ) - -cdef bint _inited = False -cdef bint _use_ex = False -cdef object _lock = threading.Lock() - -# Attribute names for identity comparison and representation _LAUNCH_CONFIG_ATTRS = ('grid', 'cluster', 'block', 'shmem_size', 'cooperative_launch') -cdef int _lazy_init() except?-1: - global _inited, _use_ex - if _inited: - return 0 - - cdef tuple _py_major_minor - cdef int _driver_ver - with _lock: - if _inited: - return 0 - - # binding availability depends on cuda-python version - _py_major_minor = get_binding_version() - HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver)) - _use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8)) - _inited = True - - return 0 - - 
cdef class LaunchConfig: """Customizable launch options. @@ -99,8 +70,6 @@ cdef class LaunchConfig: cooperative_launch : bool, optional Whether to launch as cooperative kernel (default: False) """ - _lazy_init() - # Convert and validate grid and block dimensions self.grid = cast_to_3_tuple("LaunchConfig.grid", grid) self.block = cast_to_3_tuple("LaunchConfig.block", block) @@ -110,10 +79,6 @@ cdef class LaunchConfig: # device compute capability or attributes. # thread block clusters are supported starting H100 if cluster is not None: - if not _use_ex: - err, drvers = driver.cuDriverGetVersion() - drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else "" - raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}") cc = Device().compute_capability if cc < (9, 0): raise CUDAError( @@ -153,7 +118,6 @@ cdef class LaunchConfig: return hash(self._identity()) cdef cydriver.CUlaunchConfig _to_native_launch_config(self): - _lazy_init() cdef cydriver.CUlaunchConfig drv_cfg cdef cydriver.CUlaunchAttribute attr memset(&drv_cfg, 0, sizeof(drv_cfg)) @@ -201,8 +165,6 @@ cpdef object _to_native_launch_config(LaunchConfig config): driver.CUlaunchConfig Native CUDA driver launch configuration """ - _lazy_init() - cdef object drv_cfg = driver.CUlaunchConfig() cdef list attrs cdef object attr diff --git a/cuda_core/cuda/core/_launcher.pyx b/cuda_core/cuda/core/_launcher.pyx index ce5f7339e0..82ef0777a4 100644 --- a/cuda_core/cuda/core/_launcher.pyx +++ b/cuda_core/cuda/core/_launcher.pyx @@ -15,39 +15,9 @@ from cuda.core._utils.cuda_utils cimport ( check_or_create_options, HANDLE_RETURN, ) - -import threading - from cuda.core._module import Kernel from cuda.core._stream import Stream -from cuda.core._utils.cuda_utils import ( - _reduce_3_tuple, - get_binding_version, -) - - -cdef bint _inited = False -cdef bint _use_ex = False -cdef object _lock = threading.Lock() - - -cdef int _lazy_init() except?-1: - global _inited, 
_use_ex - if _inited: - return 0 - - cdef int _driver_ver - with _lock: - if _inited: - return 0 - - # binding availability depends on cuda-python version - _py_major_minor = get_binding_version() - HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver)) - _use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8)) - _inited = True - - return 0 +from cuda.core._utils.cuda_utils import _reduce_3_tuple def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kernel: Kernel, *kernel_args): @@ -70,7 +40,6 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern """ cdef Stream s = Stream_accept(stream, allow_stream_protocol=True) - _lazy_init() cdef LaunchConfig conf = check_or_create_options(LaunchConfig, config, "launch config") # TODO: can we ensure kernel_args is valid/safe to use here? @@ -78,32 +47,15 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern cdef ParamHolder ker_args = ParamHolder(kernel_args) cdef void** args_ptr = (ker_args.ptr) - # Note: We now use CUkernel handles exclusively (CUDA 12+), but they can be cast to - # CUfunction for use with cuLaunchKernel, as both handle types are interchangeable - # for kernel launch purposes. cdef Kernel ker = kernel cdef cydriver.CUfunction func_handle = as_cu(ker._h_kernel) - # Note: CUkernel can still be launched via cuLaunchKernel (not just cuLaunchKernelEx). - # We check both binding & driver versions here mainly to see if the "Ex" API is - # available and if so we use it, as it's more feature rich. 
- if _use_ex: - drv_cfg = conf._to_native_launch_config() - drv_cfg.hStream = as_cu(s._h_stream) - if conf.cooperative_launch: - _check_cooperative_launch(kernel, conf, s) - with nogil: - HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL)) - else: - # TODO: check if config has any unsupported attrs - HANDLE_RETURN( - cydriver.cuLaunchKernel( - func_handle, - conf.grid[0], conf.grid[1], conf.grid[2], - conf.block[0], conf.block[1], conf.block[2], - conf.shmem_size, as_cu(s._h_stream), args_ptr, NULL - ) - ) + drv_cfg = conf._to_native_launch_config() + drv_cfg.hStream = as_cu(s._h_stream) + if conf.cooperative_launch: + _check_cooperative_launch(kernel, conf, s) + with nogil: + HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL)) cdef _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stream): diff --git a/cuda_core/cuda/core/_linker.pyx b/cuda_core/cuda/core/_linker.pyx index ce7c6e4528..7c7a8edde3 100644 --- a/cuda_core/cuda/core/_linker.pyx +++ b/cuda_core/cuda/core/_linker.pyx @@ -40,6 +40,7 @@ from cuda.core._utils.cuda_utils import ( handle_return, is_sequence, ) +from cuda.core._utils.version import driver_version ctypedef const char* const_char_ptr ctypedef void* void_ptr @@ -641,8 +642,7 @@ def _decide_nvjitlink_or_driver() -> bool: if _driver_ver is not None: return not _use_nvjitlink_backend - _driver_ver = handle_return(driver.cuDriverGetVersion()) - _driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10) + _driver_ver = driver_version()[:2] warn_txt_common = ( "the driver APIs will be used instead, which do not support" diff --git a/cuda_core/cuda/core/_memory/_virtual_memory_resource.py b/cuda_core/cuda/core/_memory/_virtual_memory_resource.py index 7aff5709b7..936bae7632 100644 --- a/cuda_core/cuda/core/_memory/_virtual_memory_resource.py +++ b/cuda_core/cuda/core/_memory/_virtual_memory_resource.py @@ -16,8 +16,8 @@ Transaction, check_or_create_options, driver, - 
get_binding_version, ) +from cuda.core._utils.version import binding_version from cuda.core._utils.cuda_utils import ( _check_driver_error as raise_if_driver_error, ) @@ -99,8 +99,8 @@ class VirtualMemoryResourceOptions: _t = driver.CUmemAllocationType # CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not _allocation_type = {"pinned": _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012 - ver_major, ver_minor = get_binding_version() - if ver_major >= 13: + if binding_version() >= (13, 0, 0): + _allocation_type["managed"] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED @staticmethod diff --git a/cuda_core/cuda/core/_module.pyx b/cuda_core/cuda/core/_module.pyx index 4e8f810619..2eaff7fb11 100644 --- a/cuda_core/cuda/core/_module.pyx +++ b/cuda_core/cuda/core/_module.pyx @@ -6,8 +6,6 @@ from __future__ import annotations from libc.stddef cimport size_t -import functools -import threading from collections import namedtuple from cuda.core._device import Device @@ -33,110 +31,12 @@ from cuda.core._utils.clear_error_support import ( raise_code_path_meant_to_be_unreachable, ) from cuda.core._utils.cuda_utils cimport HANDLE_RETURN -from cuda.core._utils.cuda_utils import driver, get_binding_version +from cuda.core._utils.version cimport cy_driver_version +from cuda.core._utils.cuda_utils import driver from cuda.bindings cimport cydriver __all__ = ["Kernel", "ObjectCode"] -# Lazy initialization state and synchronization -# For Python 3.13t (free-threaded builds), we use a lock to ensure thread-safe initialization. -# For regular Python builds with GIL, the lock overhead is minimal and the code remains safe. -cdef object _init_lock = threading.Lock() -cdef bint _inited = False -cdef int _py_major_ver = 0 -cdef int _py_minor_ver = 0 -cdef int _driver_ver = 0 -cdef tuple _kernel_ctypes = None -cdef bint _paraminfo_supported = False - - -cdef int _lazy_init() except -1: - """ - Initialize module-level state in a thread-safe manner. 
- - This function is thread-safe and suitable for both: - - Regular Python builds (with GIL) - - Python 3.13t free-threaded builds (without GIL) - - Uses double-checked locking pattern for performance: - - Fast path: check without lock if already initialized - - Slow path: acquire lock and initialize if needed - """ - global _inited - # Fast path: already initialized (no lock needed for read) - if _inited: - return 0 - - cdef int drv_ver - # Slow path: acquire lock and initialize - with _init_lock: - # Double-check: another thread might have initialized while we waited - if _inited: - return 0 - - global _py_major_ver, _py_minor_ver, _driver_ver, _kernel_ctypes, _paraminfo_supported - # binding availability depends on cuda-python version - _py_major_ver, _py_minor_ver = get_binding_version() - _kernel_ctypes = (driver.CUkernel,) - with nogil: - HANDLE_RETURN(cydriver.cuDriverGetVersion(&drv_ver)) - _driver_ver = drv_ver - _paraminfo_supported = _driver_ver >= 12040 - - # Mark as initialized (must be last to ensure all state is set) - _inited = True - - return 0 - - -# Auto-initializing accessors (cdef for internal use) -cdef inline int _get_py_major_ver() except -1: - """Get the Python binding major version, initializing if needed.""" - _lazy_init() - return _py_major_ver - - -cdef inline int _get_py_minor_ver() except -1: - """Get the Python binding minor version, initializing if needed.""" - _lazy_init() - return _py_minor_ver - - -cdef inline int _get_driver_ver() except -1: - """Get the CUDA driver version, initializing if needed.""" - _lazy_init() - return _driver_ver - - -cdef inline tuple _get_kernel_ctypes(): - """Get the kernel ctypes tuple, initializing if needed.""" - _lazy_init() - return _kernel_ctypes - - -cdef inline bint _is_paraminfo_supported() except -1: - """Return True if cuKernelGetParamInfo is available (driver >= 12.4).""" - _lazy_init() - return _paraminfo_supported - - -@functools.cache -def _is_cukernel_get_library_supported() -> bool: - 
"""Return True when cuKernelGetLibrary is available for inverse kernel-to-library lookup. - - Requires cuda-python bindings >= 12.5 and driver >= 12.5. - """ - return ( - (_get_py_major_ver(), _get_py_minor_ver()) >= (12, 5) - and _get_driver_ver() >= 12050 - and hasattr(driver, "cuKernelGetLibrary") - ) - - -cdef inline LibraryHandle _make_empty_library_handle(): - """Create an empty LibraryHandle to indicate no library loaded.""" - return LibraryHandle() # Empty shared_ptr - cdef class KernelAttributes: """Provides access to kernel attributes.""" @@ -149,7 +49,6 @@ cdef class KernelAttributes: cdef KernelAttributes self = KernelAttributes.__new__(KernelAttributes) self._h_kernel = h_kernel self._cache = {} - _lazy_init() return self cdef int _get_cached_attribute(self, int device_id, cydriver.CUfunction_attribute attribute) except? -1: @@ -508,11 +407,10 @@ cdef class Kernel: return self._attributes cdef tuple _get_arguments_info(self, bint param_info=False): - if not _is_paraminfo_supported(): - driver_ver = _get_driver_ver() + if cy_driver_version() < (12, 4, 0): raise NotImplementedError( "Driver version 12.4 or newer is required for this function. 
" - f"Using driver version {driver_ver // 1000}.{(driver_ver % 1000) // 10}" + f"Using driver version {'.'.join(map(str, cy_driver_version()))}" ) cdef size_t arg_pos = 0 cdef list param_info_data = [] @@ -650,7 +548,6 @@ cdef class ObjectCode: # _h_library is assigned during _lazy_load_module self._h_library = LibraryHandle() # Empty handle - _lazy_init() self._code_type = code_type self._module = module diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 96d7aa0567..5953b162f9 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -34,11 +34,11 @@ from cuda.core._utils.cuda_utils import ( CUDAError, _handle_boolean_option, check_or_create_options, - get_binding_version, handle_return, is_nested_sequence, is_sequence, ) +from cuda.core._utils.version import binding_version, driver_version __all__ = ["Program", "ProgramOptions"] @@ -520,10 +520,10 @@ def _get_nvvm_module(): _nvvm_import_attempted = True try: - version = get_binding_version() - if version < (12, 9): + version = binding_version() + if version < (12, 9, 0): raise RuntimeError( - f"NVVM bindings require cuda-bindings >= 12.9.0, but found {version[0]}.{version[1]}.x. " + f"NVVM bindings require cuda-bindings >= 12.9.0, but found {'.'.join(map(str, version))}. " "Please update cuda-bindings to use NVVM features." ) @@ -579,9 +579,9 @@ cdef inline void _process_define_macro(list options, object macro) except *: cpdef bint _can_load_generated_ptx() except? 
-1: """Check if the driver can load PTX generated by the current NVRTC version.""" - driver_ver = handle_return(driver.cuDriverGetVersion()) + drv = driver_version() nvrtc_major, nvrtc_minor = handle_return(nvrtc.nvrtcVersion()) - return nvrtc_major * 1000 + nvrtc_minor * 10 <= driver_ver + return (nvrtc_major, nvrtc_minor, 0) <= drv cdef inline object _translate_program_options(object options): diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index ec6c587f3f..e4b086062e 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -298,18 +298,6 @@ def is_nested_sequence(obj): return is_sequence(obj) and any(is_sequence(elem) for elem in obj) -@functools.cache -def get_binding_version(): - try: - major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2] - except importlib.metadata.PackageNotFoundError: - major_minor = importlib.metadata.version("cuda-python").split(".")[:2] - return tuple(int(v) for v in major_minor) - -@functools.cache -def get_driver_version(): - return handle_return(driver.cuDriverGetVersion()) - class Transaction: """ diff --git a/cuda_core/cuda/core/_utils/version.pxd b/cuda_core/cuda/core/_utils/version.pxd new file mode 100644 index 0000000000..2746d463db --- /dev/null +++ b/cuda_core/cuda/core/_utils/version.pxd @@ -0,0 +1,6 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +cdef tuple cy_binding_version() +cdef tuple cy_driver_version() diff --git a/cuda_core/cuda/core/_utils/version.pyx b/cuda_core/cuda/core/_utils/version.pyx new file mode 100644 index 0000000000..a7f764e66b --- /dev/null +++ b/cuda_core/cuda/core/_utils/version.pyx @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# SPDX-License-Identifier: Apache-2.0 + +import functools +import importlib.metadata + +from cuda.core._utils.cuda_utils import driver, handle_return + + +@functools.cache +def binding_version(): + """Return the cuda-bindings version as a (major, minor, patch) triple.""" + try: + parts = importlib.metadata.version("cuda-bindings").split(".")[:3] + except importlib.metadata.PackageNotFoundError: + parts = importlib.metadata.version("cuda-python").split(".")[:3] + return tuple(int(v) for v in parts) + + +@functools.cache +def driver_version(): + """Return the CUDA driver version as a (major, minor, patch) triple.""" + cdef int ver = handle_return(driver.cuDriverGetVersion()) + return (ver // 1000, (ver % 1000) // 10, ver % 10) + + +cdef tuple _cached_binding_version = None +cdef tuple _cached_driver_version = None + + +cdef tuple cy_binding_version(): + global _cached_binding_version + if _cached_binding_version is None: + _cached_binding_version = binding_version() + return _cached_binding_version + + +cdef tuple cy_driver_version(): + global _cached_driver_version + if _cached_driver_version is None: + _cached_driver_version = driver_version() + return _cached_driver_version diff --git a/cuda_core/tests/graph/test_explicit.py b/cuda_core/tests/graph/test_explicit.py index ab023f5ffa..33826cb5fd 100644 --- a/cuda_core/tests/graph/test_explicit.py +++ b/cuda_core/tests/graph/test_explicit.py @@ -48,18 +48,18 @@ def _skip_if_no_managed_mempool(): def _driver_has_node_get_params(): - from cuda.bindings import driver as drv + from cuda.core._utils.version import driver_version - return drv.cuDriverGetVersion()[1] >= 13020 + return driver_version() >= (13, 2, 0) _HAS_NODE_GET_PARAMS = _driver_has_node_get_params() def _bindings_major_version(): - from cuda.core._utils.cuda_utils import get_binding_version + from cuda.core._utils.version import binding_version - return get_binding_version()[0] + return binding_version()[0] _BINDINGS_MAJOR = _bindings_major_version() 
diff --git a/cuda_core/tests/test_cuda_utils.py b/cuda_core/tests/test_cuda_utils.py index 04670b96f2..f218182766 100644 --- a/cuda_core/tests/test_cuda_utils.py +++ b/cuda_core/tests/test_cuda_utils.py @@ -21,7 +21,9 @@ def test_driver_cu_result_explanations_health(): assert code in expl_dict known_codes.add(code) - if cuda_utils.get_binding_version() >= (13, 0): + from cuda.core._utils.version import binding_version + + if binding_version() >= (13, 0, 0): # Ensure expl_dict has no codes not known as a CUresult enum extra_expl = sorted(set(expl_dict.keys()) - known_codes) assert not extra_expl @@ -37,7 +39,9 @@ def test_runtime_cuda_error_explanations_health(): assert code in expl_dict known_codes.add(code) - if cuda_utils.get_binding_version() >= (13, 0): + from cuda.core._utils.version import binding_version + + if binding_version() >= (13, 0, 0): # Ensure expl_dict has no codes not known as a cudaError_t enum extra_expl = sorted(set(expl_dict.keys()) - known_codes) assert not extra_expl diff --git a/cuda_core/tests/test_device.py b/cuda_core/tests/test_device.py index 95e47ce8d9..c4e1e9931f 100644 --- a/cuda_core/tests/test_device.py +++ b/cuda_core/tests/test_device.py @@ -10,7 +10,8 @@ import cuda.core from cuda.core import Device -from cuda.core._utils.cuda_utils import ComputeCapability, get_binding_version, handle_return +from cuda.core._utils.cuda_utils import ComputeCapability, handle_return +from cuda.core._utils.version import binding_version, driver_version def test_device_init_disabled(): @@ -18,14 +19,6 @@ def test_device_init_disabled(): cuda.core._device.DeviceProperties() # Ensure back door is locked. 
-@pytest.fixture(scope="module") -def cuda_version(): - # binding availability depends on cuda-python version - _py_major_ver, _ = get_binding_version() - _driver_ver = handle_return(driver.cuDriverGetVersion()) - return _py_major_ver, _driver_ver - - def test_to_system_device(deinit_cuda): from cuda.core.system import _system @@ -115,8 +108,8 @@ def test_pci_bus_id(): def test_uuid(): device = Device() - driver_ver = handle_return(driver.cuDriverGetVersion()) - if driver_ver < 13000: + drv_ver = driver_version() + if drv_ver < (13, 0, 0): uuid = handle_return(driver.cuDeviceGetUuid_v2(device.device_id)) else: uuid = handle_return(driver.cuDeviceGetUuid(device.device_id)) @@ -306,8 +299,8 @@ def test_arch(): ("only_partial_host_native_atomic_supported", bool), ] -version = get_binding_version() -if version[0] >= 13: +version = binding_version() +if version >= (13, 0, 0): cuda_base_properties += cuda_13_properties @@ -324,7 +317,7 @@ def test_device_properties_complete(): excluded_props = set() # Exclude CUDA 13+ specific properties when not available - if version[0] < 13: + if version < (13, 0, 0): excluded_props.update({prop[0] for prop in cuda_13_properties}) filtered_tab_props = tab_props - excluded_props diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 2bc7e25d21..598b46ac7a 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -10,7 +10,8 @@ import cuda.core from cuda.core import Device, Kernel, ObjectCode, Program, ProgramOptions from cuda.core._program import _can_load_generated_ptx -from cuda.core._utils.cuda_utils import CUDAError, driver, get_binding_version, handle_return +from cuda.core._utils.cuda_utils import CUDAError, driver, handle_return +from cuda.core._utils.version import binding_version, driver_version try: import numba @@ -34,11 +35,7 @@ @pytest.fixture(scope="module") def cuda12_4_prerequisite_check(): - # binding availability depends on cuda-python version - # and version of 
underlying CUDA toolkit - _py_major_ver, _ = get_binding_version() - _driver_ver = handle_return(driver.cuDriverGetVersion()) - return _py_major_ver >= 12 and _driver_ver >= 12040 + return binding_version() >= (12, 0, 0) and driver_version() >= (12, 4, 0) def test_kernel_attributes_init_disabled(): diff --git a/cuda_core/tests/test_optional_dependency_imports.py b/cuda_core/tests/test_optional_dependency_imports.py index ebdd10e4a7..281f2b707d 100644 --- a/cuda_core/tests/test_optional_dependency_imports.py +++ b/cuda_core/tests/test_optional_dependency_imports.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE -import types - import pytest from cuda.core import _linker, _program @@ -35,17 +33,8 @@ def restore_optional_import_state(): _linker._use_nvjitlink_backend = saved_use_nvjitlink -def _patch_driver_version(monkeypatch, version=13000): - monkeypatch.setattr( - _linker, - "driver", - types.SimpleNamespace(cuDriverGetVersion=lambda: version), - ) - monkeypatch.setattr(_linker, "handle_return", lambda value: value) - - def test_get_nvvm_module_reraises_nested_module_not_found(monkeypatch): - monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9)) + monkeypatch.setattr(_program, "binding_version", lambda: (12, 9, 0)) def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvvm" @@ -62,7 +51,7 @@ def fake__optional_cuda_import(modname, probe_function=None): def test_get_nvvm_module_reports_missing_nvvm_module(monkeypatch): - monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9)) + monkeypatch.setattr(_program, "binding_version", lambda: (12, 9, 0)) def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvvm" @@ -76,7 +65,7 @@ def fake__optional_cuda_import(modname, probe_function=None): def test_get_nvvm_module_handles_missing_libnvvm(monkeypatch): - monkeypatch.setattr(_program, "get_binding_version", lambda: (12, 9)) + 
monkeypatch.setattr(_program, "binding_version", lambda: (12, 9, 0)) def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvvm" @@ -90,7 +79,7 @@ def fake__optional_cuda_import(modname, probe_function=None): def test_decide_nvjitlink_or_driver_reraises_nested_module_not_found(monkeypatch): - _patch_driver_version(monkeypatch) + monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0)) def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvjitlink" @@ -107,7 +96,7 @@ def fake__optional_cuda_import(modname, probe_function=None): def test_decide_nvjitlink_or_driver_falls_back_when_module_missing(monkeypatch): - _patch_driver_version(monkeypatch) + monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0)) def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvjitlink" diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 2507b82f0d..07e2ed0c94 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -15,7 +15,6 @@ pytest_plugins = ("cuda_python_test_helpers.nvvm_bitcode",) -cuda_driver_version = handle_return(driver.cuDriverGetVersion()) is_culink_backend = _linker._decide_nvjitlink_or_driver() @@ -36,10 +35,8 @@ def _is_nvvm_available(): try: from cuda.core._utils.cuda_utils import driver, handle_return, nvrtc - - _cuda_driver_version = handle_return(driver.cuDriverGetVersion()) except Exception: - _cuda_driver_version = 0 + pass def _get_nvrtc_version_for_tests(): From 51b8f63fd1379be6d5b5236088b1e3fb577e065f Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 25 Mar 2026 13:06:45 -0700 Subject: [PATCH 5/8] Fix unused imports after merge with main Remove unused imports flagged by cython-lint and ruff after resolving merge conflicts with origin/main. 
Made-with: Cursor --- cuda_core/cuda/core/_launch_config.pyx | 4 ---- cuda_core/cuda/core/_linker.pyx | 1 - cuda_core/cuda/core/_memory/_virtual_memory_resource.py | 3 +-- cuda_core/cuda/core/_program.pyx | 2 +- cuda_core/cuda/core/_utils/cuda_utils.pyx | 1 - cuda_core/tests/test_program.py | 9 ++++----- 6 files changed, 6 insertions(+), 14 deletions(-) diff --git a/cuda_core/cuda/core/_launch_config.pyx b/cuda_core/cuda/core/_launch_config.pyx index 285110a52b..0970ea36c7 100644 --- a/cuda_core/cuda/core/_launch_config.pyx +++ b/cuda_core/cuda/core/_launch_config.pyx @@ -4,10 +4,6 @@ from libc.string cimport memset -from cuda.core._utils.cuda_utils cimport ( - HANDLE_RETURN, -) - from cuda.core._device import Device from cuda.core._utils.cuda_utils import ( CUDAError, diff --git a/cuda_core/cuda/core/_linker.pyx b/cuda_core/cuda/core/_linker.pyx index 7c7a8edde3..6a6606da7a 100644 --- a/cuda_core/cuda/core/_linker.pyx +++ b/cuda_core/cuda/core/_linker.pyx @@ -37,7 +37,6 @@ from cuda.core._utils.cuda_utils import ( CUDAError, check_or_create_options, driver, - handle_return, is_sequence, ) from cuda.core._utils.version import driver_version diff --git a/cuda_core/cuda/core/_memory/_virtual_memory_resource.py b/cuda_core/cuda/core/_memory/_virtual_memory_resource.py index 936bae7632..7d952e102f 100644 --- a/cuda_core/cuda/core/_memory/_virtual_memory_resource.py +++ b/cuda_core/cuda/core/_memory/_virtual_memory_resource.py @@ -17,10 +17,10 @@ check_or_create_options, driver, ) -from cuda.core._utils.version import binding_version from cuda.core._utils.cuda_utils import ( _check_driver_error as raise_if_driver_error, ) +from cuda.core._utils.version import binding_version __all__ = ["VirtualMemoryResource", "VirtualMemoryResourceOptions"] @@ -100,7 +100,6 @@ class VirtualMemoryResourceOptions: # CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not _allocation_type = {"pinned": _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012 if binding_version() >= 
(13, 0, 0): - _allocation_type["managed"] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED @staticmethod diff --git a/cuda_core/cuda/core/_program.pyx b/cuda_core/cuda/core/_program.pyx index 5953b162f9..194ef6da53 100644 --- a/cuda_core/cuda/core/_program.pyx +++ b/cuda_core/cuda/core/_program.pyx @@ -13,7 +13,7 @@ from dataclasses import dataclass import threading from warnings import warn -from cuda.bindings import driver, nvrtc +from cuda.bindings import nvrtc from cuda.pathfinder._optional_cuda_import import _optional_cuda_import from libcpp.vector cimport vector diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index e4b086062e..afaa14f134 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -4,7 +4,6 @@ import functools from functools import partial -import importlib.metadata import multiprocessing import platform import warnings diff --git a/cuda_core/tests/test_program.py b/cuda_core/tests/test_program.py index 07e2ed0c94..ac40fb735d 100644 --- a/cuda_core/tests/test_program.py +++ b/cuda_core/tests/test_program.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE +import contextlib import re import warnings @@ -11,7 +12,7 @@ from cuda.core._device import Device from cuda.core._module import Kernel, ObjectCode from cuda.core._program import Program, ProgramOptions -from cuda.core._utils.cuda_utils import CUDAError, driver, handle_return +from cuda.core._utils.cuda_utils import CUDAError, handle_return pytest_plugins = ("cuda_python_test_helpers.nvvm_bitcode",) @@ -33,10 +34,8 @@ def _is_nvvm_available(): not _is_nvvm_available(), reason="NVVM not available (libNVVM not found or cuda-bindings < 12.9.0)" ) -try: - from cuda.core._utils.cuda_utils import driver, handle_return, nvrtc -except Exception: - pass +with contextlib.suppress(Exception): + from cuda.core._utils.cuda_utils import nvrtc def _get_nvrtc_version_for_tests(): From 
110d6de9334aedf42e2d0b2d70e85a9aee2f1636 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 25 Mar 2026 14:28:48 -0700 Subject: [PATCH 6/8] Replace _reduce_3_tuple with math.prod in _launcher.pyx Remove the now-dead _reduce_3_tuple helper from cuda_utils.pyx. Made-with: Cursor --- cuda_core/cuda/core/_launcher.pyx | 6 +++--- cuda_core/cuda/core/_utils/cuda_utils.pyx | 4 ---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/cuda_core/cuda/core/_launcher.pyx b/cuda_core/cuda/core/_launcher.pyx index 82ef0777a4..f8189d95ed 100644 --- a/cuda_core/cuda/core/_launcher.pyx +++ b/cuda_core/cuda/core/_launcher.pyx @@ -17,7 +17,7 @@ from cuda.core._utils.cuda_utils cimport ( ) from cuda.core._module import Kernel from cuda.core._stream import Stream -from cuda.core._utils.cuda_utils import _reduce_3_tuple +from math import prod def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kernel: Kernel, *kernel_args): @@ -62,9 +62,9 @@ cdef _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Str dev = stream.device num_sm = dev.properties.multiprocessor_count max_grid_size = ( - kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_3_tuple(config.block), config.shmem_size) * num_sm + kernel.occupancy.max_active_blocks_per_multiprocessor(prod(config.block), config.shmem_size) * num_sm ) - if _reduce_3_tuple(config.grid) > max_grid_size: + if prod(config.grid) > max_grid_size: # For now let's try not to be smart and adjust the grid size behind users' back. # We explicitly ask users to adjust. 
x, y, z = config.grid diff --git a/cuda_core/cuda/core/_utils/cuda_utils.pyx b/cuda_core/cuda/core/_utils/cuda_utils.pyx index afaa14f134..867d066ce2 100644 --- a/cuda_core/cuda/core/_utils/cuda_utils.pyx +++ b/cuda_core/cuda/core/_utils/cuda_utils.pyx @@ -60,10 +60,6 @@ def cast_to_3_tuple(label, cfg): return cfg + (1,) * (3 - len(cfg)) -def _reduce_3_tuple(t: tuple): - return t[0] * t[1] * t[2] - - cdef int HANDLE_RETURN(cydriver.CUresult err) except?-1 nogil: if err != cydriver.CUresult.CUDA_SUCCESS: return _check_driver_error(err) From 024ede198a6a5eed5bac65fe698f8e754b252e28 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 25 Mar 2026 14:39:16 -0700 Subject: [PATCH 7/8] Remove _driver_ver from _linker.pyx; use _use_nvjitlink_backend as guard Initialize _use_nvjitlink_backend to None so it can serve as its own "already decided" sentinel, eliminating the redundant _driver_ver variable and the driver_version() call that was only used to set it. Made-with: Cursor --- cuda_core/cuda/core/_linker.pyx | 11 ++++------- cuda_core/tests/test_optional_dependency_imports.py | 9 +-------- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/cuda_core/cuda/core/_linker.pyx b/cuda_core/cuda/core/_linker.pyx index 6a6606da7a..cde117b1bb 100644 --- a/cuda_core/cuda/core/_linker.pyx +++ b/cuda_core/cuda/core/_linker.pyx @@ -39,7 +39,6 @@ from cuda.core._utils.cuda_utils import ( driver, is_sequence, ) -from cuda.core._utils.version import driver_version ctypedef const char* const_char_ptr ctypedef void* void_ptr @@ -620,9 +619,8 @@ cdef inline void Linker_annotate_error_log(Linker self, object e): # TODO: revisit this treatment for py313t builds _driver = None # populated if nvJitLink cannot be used -_driver_ver = None _inited = False -_use_nvjitlink_backend = False # set by _decide_nvjitlink_or_driver() +_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver() # Input type mappings populated by _lazy_init() with C-level enum ints. 
_nvjitlink_input_types = None @@ -637,12 +635,10 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool: # Note: this function is reused in the tests def _decide_nvjitlink_or_driver() -> bool: """Return True if falling back to the cuLink* driver APIs.""" - global _driver_ver, _driver, _use_nvjitlink_backend - if _driver_ver is not None: + global _driver, _use_nvjitlink_backend + if _use_nvjitlink_backend is not None: return not _use_nvjitlink_backend - _driver_ver = driver_version()[:2] - warn_txt_common = ( "the driver APIs will be used instead, which do not support" " minor version compatibility or linking LTO IRs." @@ -667,6 +663,7 @@ def _decide_nvjitlink_or_driver() -> bool: ) warn(warn_txt, stacklevel=2, category=RuntimeWarning) + _use_nvjitlink_backend = False _driver = driver return True diff --git a/cuda_core/tests/test_optional_dependency_imports.py b/cuda_core/tests/test_optional_dependency_imports.py index 281f2b707d..730c6e7834 100644 --- a/cuda_core/tests/test_optional_dependency_imports.py +++ b/cuda_core/tests/test_optional_dependency_imports.py @@ -12,23 +12,20 @@ def restore_optional_import_state(): saved_nvvm_module = _program._nvvm_module saved_nvvm_attempted = _program._nvvm_import_attempted saved_driver = _linker._driver - saved_driver_ver = _linker._driver_ver saved_inited = _linker._inited saved_use_nvjitlink = _linker._use_nvjitlink_backend _program._nvvm_module = None _program._nvvm_import_attempted = False _linker._driver = None - _linker._driver_ver = None _linker._inited = False - _linker._use_nvjitlink_backend = False + _linker._use_nvjitlink_backend = None yield _program._nvvm_module = saved_nvvm_module _program._nvvm_import_attempted = saved_nvvm_attempted _linker._driver = saved_driver - _linker._driver_ver = saved_driver_ver _linker._inited = saved_inited _linker._use_nvjitlink_backend = saved_use_nvjitlink @@ -79,8 +76,6 @@ def fake__optional_cuda_import(modname, probe_function=None): def 
test_decide_nvjitlink_or_driver_reraises_nested_module_not_found(monkeypatch): - monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0)) - def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvjitlink" assert probe_function is not None @@ -96,8 +91,6 @@ def fake__optional_cuda_import(modname, probe_function=None): def test_decide_nvjitlink_or_driver_falls_back_when_module_missing(monkeypatch): - monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0)) - def fake__optional_cuda_import(modname, probe_function=None): assert modname == "cuda.bindings.nvjitlink" assert probe_function is not None From af3d27581ea68a073c80f0cdc16ce4480ff96130 Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Wed, 25 Mar 2026 15:03:18 -0700 Subject: [PATCH 8/8] Add return type annotations to version.pyx; fix minor arithmetic Add -> tuple[int, int, int] annotations to binding_version and driver_version. Align driver_version arithmetic with _system.pyx. 
Made-with: Cursor --- cuda_core/cuda/core/_utils/version.pyx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cuda_core/cuda/core/_utils/version.pyx b/cuda_core/cuda/core/_utils/version.pyx index a7f764e66b..09ea585242 100644 --- a/cuda_core/cuda/core/_utils/version.pyx +++ b/cuda_core/cuda/core/_utils/version.pyx @@ -9,7 +9,7 @@ from cuda.core._utils.cuda_utils import driver, handle_return @functools.cache -def binding_version(): +def binding_version() -> tuple[int, int, int]: """Return the cuda-bindings version as a (major, minor, patch) triple.""" try: parts = importlib.metadata.version("cuda-bindings").split(".")[:3] @@ -19,10 +19,10 @@ def binding_version(): @functools.cache -def driver_version(): +def driver_version() -> tuple[int, int, int]: """Return the CUDA driver version as a (major, minor, patch) triple.""" cdef int ver = handle_return(driver.cuDriverGetVersion()) - return (ver // 1000, (ver % 1000) // 10, ver % 10) + return (ver // 1000, (ver // 10) % 100, ver % 10) cdef tuple _cached_binding_version = None