Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cuda_core/tests/graph/test_device_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import pytest
from helpers.marks import requires_module

from cuda.core import (
Device,
Expand Down Expand Up @@ -75,7 +76,7 @@ def _compile_device_launcher_kernel():
Device().compute_capability.major < 9,
reason="Device-side graph launch requires Hopper (sm_90+) architecture",
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_device_launch_basic(init_cuda):
"""Test basic device-side graph launch functionality.

Expand Down Expand Up @@ -127,7 +128,7 @@ def test_device_launch_basic(init_cuda):
Device().compute_capability.major < 9,
reason="Device-side graph launch requires Hopper (sm_90+) architecture",
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_device_launch_multiple(init_cuda):
"""Test that device-side graph launch can be executed multiple times.

Expand Down
3 changes: 2 additions & 1 deletion cuda_core/tests/graph/test_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from helpers.graph_kernels import compile_common_kernels
from helpers.marks import requires_module

from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch

Expand Down Expand Up @@ -116,7 +117,7 @@ def test_graph_is_join_required(init_cuda):
gb.end_building().complete()


@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_repeat_capture(init_cuda):
mod = compile_common_kernels()
add_one = mod.get_kernel("add_one")
Expand Down
9 changes: 5 additions & 4 deletions cuda_core/tests/graph/test_graph_builder_conditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
import numpy as np
import pytest
from helpers.graph_kernels import compile_conditional_kernels
from helpers.marks import requires_module

from cuda.core import Device, GraphBuilder, LaunchConfig, LegacyPinnedMemoryResource, launch


@pytest.mark.parametrize(
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_conditional_if(init_cuda, condition_value):
mod = compile_conditional_kernels(type(condition_value))
add_one = mod.get_kernel("add_one")
Expand Down Expand Up @@ -79,7 +80,7 @@ def test_graph_conditional_if(init_cuda, condition_value):
@pytest.mark.parametrize(
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False), 1, 0]
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_conditional_if_else(init_cuda, condition_value):
mod = compile_conditional_kernels(type(condition_value))
add_one = mod.get_kernel("add_one")
Expand Down Expand Up @@ -151,7 +152,7 @@ def test_graph_conditional_if_else(init_cuda, condition_value):


@pytest.mark.parametrize("condition_value", [0, 1, 2, 3])
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_conditional_switch(init_cuda, condition_value):
mod = compile_conditional_kernels(type(condition_value))
add_one = mod.get_kernel("add_one")
Expand Down Expand Up @@ -242,7 +243,7 @@ def test_graph_conditional_switch(init_cuda, condition_value):


@pytest.mark.parametrize("condition_value", [True, False, 1, 0])
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_conditional_while(init_cuda, condition_value):
mod = compile_conditional_kernels(type(condition_value))
add_one = mod.get_kernel("add_one")
Expand Down
5 changes: 3 additions & 2 deletions cuda_core/tests/graph/test_graph_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
import numpy as np
import pytest
from helpers.graph_kernels import compile_common_kernels, compile_conditional_kernels
from helpers.marks import requires_module

from cuda.core import Device, LaunchConfig, LegacyPinnedMemoryResource, launch
from cuda.core._graph._graph_def import GraphDef
from cuda.core._utils.cuda_utils import CUDAError


@pytest.mark.parametrize("builder", ["GraphBuilder", "GraphDef"])
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_update_kernel_args(init_cuda, builder):
"""Update redirects a kernel to write to a different pointer."""
mod = compile_common_kernels()
Expand Down Expand Up @@ -59,7 +60,7 @@ def build(ptr):
b.close()


@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_graph_update_conditional(init_cuda):
"""Update swaps conditional switch graphs with matching topology."""
mod = compile_conditional_kernels(int)
Expand Down
45 changes: 45 additions & 0 deletions cuda_core/tests/helpers/marks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

"""Reusable pytest marks for cuda_core tests."""

import inspect

import pytest


def requires_module(module, *args, **kwargs):
"""Skip the test if a module is missing or older than required.

Thin wrapper around :func:`pytest.importorskip`. The first argument
may be a module object or a string; all remaining positional and
keyword arguments (``minversion``, ``reason``, ``exc_type``) are
forwarded.

Prefer this over ``pytest.importorskip`` when:

- You need finer granularity than module scope or a test body; this
mark can decorate classes, individual tests, or ``pytest.param`` entries.
- You want to skip before fixtures run, avoiding setup costs.
- The module is already imported and you want to pass it directly.

Usage::

@requires_module("numpy", "2.1")
def test_foo(): ...


@requires_module(np, minversion="2.1")
def test_bar(): ...
"""
if inspect.ismodule(module):
module = module.__name__
elif not isinstance(module, str):
raise TypeError(f"expected module or string, got {type(module).__name__}")

try:
pytest.importorskip(module, *args, **kwargs)
except pytest.skip.Skipped as exc:
return pytest.mark.skipif(True, reason=str(exc))
else:
return pytest.mark.skipif(False, reason="")
8 changes: 3 additions & 5 deletions cuda_core/tests/test_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ctypes

import helpers
from helpers.marks import requires_module
from helpers.misc import StreamWrapper

try:
Expand Down Expand Up @@ -190,7 +191,7 @@ def test_launch_invalid_values(init_cuda):


@pytest.mark.parametrize("python_type, cpp_type, init_value", PARAMS)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
@requires_module(np, "2.1")
def test_launch_scalar_argument(python_type, cpp_type, init_value):
dev = Device()
dev.set_current()
Expand Down Expand Up @@ -289,10 +290,7 @@ def test_cooperative_launch():
"device_memory_resource", # kludgy, but can go away after #726 is resolved
pytest.param(
LegacyPinnedMemoryResource,
marks=pytest.mark.skipif(
tuple(int(i) for i in np.__version__.split(".")[:3]) < (2, 2, 5),
reason="need numpy 2.2.5+, numpy GH #28632",
),
marks=requires_module(np, "2.2.5", reason="need numpy 2.2.5+ (numpy GH #28632)"),
),
],
)
Expand Down
5 changes: 2 additions & 3 deletions cuda_core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ml_dtypes = None
import numpy as np
import pytest
from helpers.marks import requires_module

from cuda.core import Device
from cuda.core._dlpack import DLDeviceType
Expand Down Expand Up @@ -85,9 +86,7 @@ def convert_strides_to_counts(strides, itemsize):
# readonly is fixed recently (numpy/numpy#26501)
pytest.param(
np.frombuffer(b""),
marks=pytest.mark.skipif(
tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+"
),
marks=requires_module(np, "2.1"),
),
),
)
Expand Down
Loading