Extract requires() test mark to eliminate repeated numpy version checks#1844
Extract requires() test mark to eliminate repeated numpy version checks#1844Andy-Jost merged 4 commits intoNVIDIA:mainfrom
Conversation
Add helpers/marks.py with a reusable requires() decorator and replace all inline numpy version skipif patterns across test files. Made-with: Cursor
This comment has been minimized.
This comment has been minimized.
cuda_core/tests/helpers/marks.py
Outdated
| import pytest | ||
|
|
||
|
|
||
| def requires(module, *version): |
There was a problem hiding this comment.
The naming is ok here in the (one) definition, but opaque in the N call sites. Could you please make this need_version (which aligns with the skip message)?
There was a problem hiding this comment.
How about requires_module? Though I think requires("numpy", ... is unmistakably clear.
There was a problem hiding this comment.
When I wrote my comment I had this in mind:
@requires(np, 2, 1)
That's a small head-scratcher at first sight. requires_module makes it better, yes.
What do you think about aligning it with pytest.importorskip like this:
@requires_module(np, minversion="2.1.0")
That way it's totally obvious that we're dealing with a version number, and what we're using it for.
There was a problem hiding this comment.
Yes, I agree aligning minversion with importerskip makes perfect sense. And no problem changing to requires_module.
cuda_core/tests/helpers/marks.py
Outdated
| if isinstance(module, str): | ||
| name = module | ||
| try: | ||
| module = importlib.import_module(name) |
There was a problem hiding this comment.
Is this code being used anywhere (the string to module path)? I don't see it in this PR.
We should add it when we need it, since that's more stuff review and maintain.
There was a problem hiding this comment.
I only updated the NumPy skips in this PR, but I expect we would find uses for the string version.
The string version collapses the current two-step process (try/except import, unguarded import, or importerskip at module scope + pytest.skip elsewhere) into a single step. This is useful when a version check is needed, but the checked module is not used directly.
A prime use case is to simplify cuda.bindings checks, replacing this:
import cuda.bindings
_cuda_major = int(cuda.bindings.__version__.split(".")[0])
requires_cuda_13 = pytest.mark.skipif(
_cuda_major < 13,
reason="ManagedMemoryResource requires CUDA 13.0 or later",
)
with this:
@requires("cuda.bindings", 13)
def test...
cuda_core/tests/helpers/marks.py
Outdated
| parts = module.__version__.split(".")[:n] | ||
| installed = tuple(int(p) for p in parts) | ||
| ver_str = ".".join(str(v) for v in version) | ||
| return pytest.mark.skipif(installed < version, reason=f"need {name} {ver_str}+") |
There was a problem hiding this comment.
Actually, can't we replace this entire thing with this:
pytest.importorskip("numpy", minversion="2.1.0")?
There was a problem hiding this comment.
+1
Much better.
(I didn't know that importorskip has the minversion feature.)
There was a problem hiding this comment.
No, because that operates at the module level. We have several individual tests that require NumPy buried in larger test modules.
Not to say it can't work. We can put importerskip into every test function body and we do that in some places. However, a mark is cheaper, arguably clearer, and amounts to a smaller change to the existing tests.
A mark that runs at collection time can be better because:
- Fixtures are expensive:
init_cudacreates a CUDA context. Withimportorskipinside the body, that context is created and torn down even for a test that will just skip because, e.g.,cupyisn't installed. pytest.parammarks:marks=requires("cupy", 14)works cleanly in parametrize lists. There's no way to putimportorskipinside apytest.param.- Multiple tests sharing the same requirement: one
requires("cupy", 14)per test class is cleaner than callingimportorskipin every body. Or evenpytestmark = requires("cupy", 14)at module level to skip the entire file.
There was a problem hiding this comment.
Could you please try this pattern (I expect it'll work, although I haven't tested it myself):
def test_needs_numpy():
np = pytest.importorskip("numpy", minversion="2.1.0")
There was a problem hiding this comment.
It's not module level. importorskip just raises pytest.Skip (the mechanism used for skip control flow) wherever you call it.
If it's in a test, that'll skip just the test.
If it's at module scope, then it'll skip the whole module.
There was a problem hiding this comment.
On my end, Claude also kept saying that using importerskip to implement requires is awkward, but I just don't buy the argument. Using something that does exactly what we want as a way to avoid reimplementing that thing seems... great.
There was a problem hiding this comment.
FYI, here is the source for importerskip:
def importorskip(
modname: str,
minversion: str | None = None,
reason: str | None = None,
*,
exc_type: type[ImportError] | None = None,
) -> Any:
# ... docstring ...
import warnings
__tracebackhide__ = True
compile(modname, "", "eval") # to catch syntaxerrors
if exc_type is None:
exc_type = ImportError
warn_on_import_error = True
else:
warn_on_import_error = False
skipped: Skipped | None = None
warning: Warning | None = None
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
__import__(modname)
except exc_type as exc:
if reason is None:
reason = f"could not import {modname!r}: {exc}"
skipped = Skipped(reason, allow_module_level=True)
if warn_on_import_error and not isinstance(exc, ModuleNotFoundError):
# ... deprecation warning ...
pass
if warning:
warnings.warn(warning, stacklevel=2)
if skipped:
raise skipped
mod = sys.modules[modname]
if minversion is None:
return mod
verattr = getattr(mod, "__version__", None)
if minversion is not None:
from packaging.version import Version
if verattr is None or Version(verattr) < Version(minversion):
raise Skipped(
f"module {modname!r} has __version__ {verattr!r}, required is: {minversion!r}",
allow_module_level=True,
)
return modThere was a problem hiding this comment.
Sounds good. Ultimately my main concern is that we land on an intuitive (at the call sites; for humans) API. If we unexpectedly run into issues with the implementation, we can tweak then. Using the compact version sounds fine.
There was a problem hiding this comment.
if warn_on_import_error and not isinstance(exc, ModuleNotFoundError):
# ... deprecation warning ...
pass
Did you edit it for April Fools'?
There was a problem hiding this comment.
Ultimately my main concern is that we land on an intuitive (at the call sites; for humans) API
I'm glad we went with requires_module because after looking through other skips I already want to add requires_compute_capability and more.
Did you edit it for April Fools'?
I should have! That's just a comment in place of the code that issues a deprecation warning. The default exc_type will change from ImportError to ModuleNotFoundError in pytest version 9.1
Rename the mark to requires_module and reimplement it as a thin wrapper around pytest.importorskip, forwarding *args/**kwargs directly. Version arguments are now strings (matching importorskip's minversion parameter) rather than integer tuples. Update all call sites accordingly. Made-with: Cursor
Made-with: Cursor
…y-mark Made-with: Cursor # Conflicts: # cuda_core/tests/graph/test_graph_update.py
|
Summary
requires_module(module, *args, **kwargs)pytest mark intests/helpers/marks.pythat skips tests when a module is missing or too oldpytest.mark.skipif(tuple(int(i) for i in np.__version__...patterns across 6 test filesUPDATE: Renamed from
requirestorequires_moduleand reworked the implementation to delegate topytest.importorskip. Arguments now matchimportorskip's interface (minversionas a string, e.g."2.1") rather than an integer version tuple.Changes
tests/helpers/marks.py:requires_module()accepts a module object or string name, forwards remaining args/kwargs topytest.importorskip, and returns askipifmarktest_advanced.py,test_basic.py,test_conditional.py,test_device_launch.py,test_launcher.py,test_utils.pyTest plan