Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cuda_bindings/cuda/bindings/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
38 changes: 38 additions & 0 deletions cuda_bindings/tests/test_enum_explanations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import importlib
import importlib.metadata

import pytest

from cuda.bindings import driver, runtime

_EXPLANATION_MODULES = [
("driver_cu_result_explanations", "DRIVER_CU_RESULT_EXPLANATIONS", driver.CUresult),
("runtime_cuda_error_explanations", "RUNTIME_CUDA_ERROR_EXPLANATIONS", runtime.cudaError_t),
]


def _get_binding_version():
try:
major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
except importlib.metadata.PackageNotFoundError:
major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
return tuple(int(v) for v in major_minor)


@pytest.mark.parametrize("module_name,dict_name,enum_type", _EXPLANATION_MODULES)
def test_explanations_health(module_name, dict_name, enum_type):
mod = importlib.import_module(f"cuda.bindings._utils.{module_name}")
expl_dict = getattr(mod, dict_name)

known_codes = set()
for error in enum_type:
code = int(error)
assert code in expl_dict
known_codes.add(code)

if _get_binding_version() >= (13, 0):
extra_expl = sorted(set(expl_dict.keys()) - known_codes)
assert not extra_expl
11 changes: 9 additions & 2 deletions cuda_core/cuda/core/_utils/cuda_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,15 @@ from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release, Py_buffer, PyB

from cuda.bindings cimport cynvrtc, cynvvm, cynvjitlink

from cuda.core._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS
from cuda.core._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS
try:
from cuda.bindings._utils.driver_cu_result_explanations import DRIVER_CU_RESULT_EXPLANATIONS
except ModuleNotFoundError:
DRIVER_CU_RESULT_EXPLANATIONS = {}

try:
from cuda.bindings._utils.runtime_cuda_error_explanations import RUNTIME_CUDA_ERROR_EXPLANATIONS
except ModuleNotFoundError:
RUNTIME_CUDA_ERROR_EXPLANATIONS = {}


class CUDAError(Exception):
Expand Down
36 changes: 0 additions & 36 deletions cuda_core/tests/test_cuda_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,6 @@
from cuda.core._utils.clear_error_support import assert_type_str_or_bytes_like, raise_code_path_meant_to_be_unreachable


def test_driver_cu_result_explanations_health():
expl_dict = cuda_utils.DRIVER_CU_RESULT_EXPLANATIONS

# Ensure all CUresult enums are in expl_dict
known_codes = set()
for error in driver.CUresult:
code = int(error)
assert code in expl_dict
known_codes.add(code)

from cuda.core._utils.version import binding_version

if binding_version() >= (13, 0, 0):
# Ensure expl_dict has no codes not known as a CUresult enum
extra_expl = sorted(set(expl_dict.keys()) - known_codes)
assert not extra_expl


def test_runtime_cuda_error_explanations_health():
expl_dict = cuda_utils.RUNTIME_CUDA_ERROR_EXPLANATIONS

# Ensure all cudaError_t enums are in expl_dict
known_codes = set()
for error in runtime.cudaError_t:
code = int(error)
assert code in expl_dict
known_codes.add(code)

from cuda.core._utils.version import binding_version

if binding_version() >= (13, 0, 0):
# Ensure expl_dict has no codes not known as a cudaError_t enum
extra_expl = sorted(set(expl_dict.keys()) - known_codes)
assert not extra_expl


def test_check_driver_error():
num_unexpected = 0
for error in driver.CUresult:
Expand Down
Loading