From d8a20311b609b75274be0e03759c8044f7430710 Mon Sep 17 00:00:00 2001 From: Keith Kraus Date: Tue, 28 Apr 2026 12:18:26 -0400 Subject: [PATCH 1/2] Add CUDA process checkpointing helpers --- cuda_core/cuda/core/__init__.py | 2 +- cuda_core/cuda/core/checkpoint.py | 182 ++++++++++++++++ cuda_core/docs/source/api.rst | 16 ++ cuda_core/docs/source/release/1.0.0-notes.rst | 5 +- cuda_core/tests/test_checkpoint.py | 203 ++++++++++++++++++ 5 files changed, 406 insertions(+), 2 deletions(-) create mode 100644 cuda_core/cuda/core/checkpoint.py create mode 100644 cuda_core/tests/test_checkpoint.py diff --git a/cuda_core/cuda/core/__init__.py b/cuda_core/cuda/core/__init__.py index dfd52accea3..3152c9ceacf 100644 --- a/cuda_core/cuda/core/__init__.py +++ b/cuda_core/cuda/core/__init__.py @@ -28,7 +28,7 @@ def _import_versioned_module(): del _import_versioned_module -from cuda.core import system, utils +from cuda.core import checkpoint, system, utils from cuda.core._device import Device from cuda.core._event import Event, EventOptions from cuda.core._graphics import GraphicsResource diff --git a/cuda_core/cuda/core/checkpoint.py b/cuda_core/cuda/core/checkpoint.py new file mode 100644 index 00000000000..ad0be778974 --- /dev/null +++ b/cuda_core/cuda/core/checkpoint.py @@ -0,0 +1,182 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from collections.abc import Mapping as _Mapping +from dataclasses import dataclass as _dataclass +from enum import IntEnum as _IntEnum +from typing import Any as _Any + +from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return + +try: + from cuda.bindings import driver as _driver +except ImportError: + from cuda import cuda as _driver + + +class ProcessState(_IntEnum): + """ + CUDA checkpoint state for a process. + """ + + RUNNING = 0 + LOCKED = 1 + CHECKPOINTED = 2 + FAILED = 3 + + +@_dataclass(frozen=True) +class Process: + """ + CUDA process that can be locked, checkpointed, restored, and unlocked. + + Parameters + ---------- + pid : int + Process ID of the CUDA process. + """ + + pid: int + + def __post_init__(self): + _check_pid(self.pid) + + @property + def state(self) -> ProcessState: + """ + CUDA checkpoint state for this process. + """ + driver = _get_driver() + state = _handle_return(driver, driver.cuCheckpointProcessGetState(self.pid)) + return ProcessState(int(state)) + + @property + def restore_thread_id(self) -> int: + """ + CUDA restore thread ID for this process. + """ + driver = _get_driver() + return _handle_return(driver, driver.cuCheckpointProcessGetRestoreThreadId(self.pid)) + + def lock(self, timeout_ms: int = 0) -> None: + """ + Lock this process, blocking further CUDA API calls. + + Parameters + ---------- + timeout_ms : int, optional + Timeout in milliseconds. A value of 0 indicates no timeout. + """ + driver = _get_driver() + args = driver.CUcheckpointLockArgs() + args.timeoutMs = _check_timeout_ms(timeout_ms) + _handle_return(driver, driver.cuCheckpointProcessLock(self.pid, args)) + + def checkpoint(self) -> None: + """ + Checkpoint the GPU memory contents of this locked process. + """ + driver = _get_driver() + _handle_return(driver, driver.cuCheckpointProcessCheckpoint(self.pid, None)) + + def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None: + """ + Restore this checkpointed process. + + Parameters + ---------- + gpu_mapping : mapping, optional + GPU UUID remapping from each checkpointed GPU UUID to the GPU UUID + to restore onto. If provided, the mapping must contain every + checkpointed GPU UUID. + """ + driver = _get_driver() + args = _make_restore_args(driver, gpu_mapping) + _handle_return(driver, driver.cuCheckpointProcessRestore(self.pid, args)) + + def unlock(self) -> None: + """ + Unlock this locked process so it can resume CUDA API calls. + """ + driver = _get_driver() + _handle_return(driver, driver.cuCheckpointProcessUnlock(self.pid, None)) + + +def _get_driver(): + required = ( + "cuCheckpointProcessCheckpoint", + "cuCheckpointProcessGetRestoreThreadId", + "cuCheckpointProcessGetState", + "cuCheckpointProcessLock", + "cuCheckpointProcessRestore", + "cuCheckpointProcessUnlock", + "CUcheckpointGpuPair", + "CUcheckpointLockArgs", + "CUcheckpointRestoreArgs", + ) + missing = [name for name in required if not hasattr(_driver, name)] + if missing: + raise RuntimeError( + f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}" + ) + return _driver + + +def _handle_return(driver, result): + err = result[0] + not_supported_errors = ( + getattr(driver.CUresult, "CUDA_ERROR_NOT_FOUND", None), + getattr(driver.CUresult, "CUDA_ERROR_NOT_SUPPORTED", None), + ) + if err in not_supported_errors: + raise RuntimeError( + "CUDA checkpointing is not supported by the installed NVIDIA driver. " + "Upgrade to a driver version with CUDA checkpoint API support." + ) + + return _handle_cuda_return(result) + + +def _check_pid(pid: int) -> int: + if isinstance(pid, bool) or not isinstance(pid, int): + raise TypeError("pid must be an int") + if pid <= 0: + raise ValueError("pid must be a positive int") + return pid + + +def _check_timeout_ms(timeout_ms: int) -> int: + if isinstance(timeout_ms, bool) or not isinstance(timeout_ms, int): + raise TypeError("timeout_ms must be an int") + if timeout_ms < 0: + raise ValueError("timeout_ms must be >= 0") + return timeout_ms + + +def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None): + if gpu_mapping is None: + return None + if not isinstance(gpu_mapping, _Mapping): + raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID") + + pairs = [] + for old_uuid, new_uuid in gpu_mapping.items(): + pair = driver.CUcheckpointGpuPair() + pair.oldUuid = old_uuid + pair.newUuid = new_uuid + pairs.append(pair) + + if not pairs: + return None + + args = driver.CUcheckpointRestoreArgs() + args.gpuPairs = pairs + args.gpuPairsCount = len(pairs) + return args + + +__all__ = [ + "Process", + "ProcessState", +] diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst index 88780732d54..5d7efdb2d17 100644 --- a/cuda_core/docs/source/api.rst +++ b/cuda_core/docs/source/api.rst @@ -174,6 +174,22 @@ CUDA compilation toolchain LinkerOptions +CUDA process checkpointing +-------------------------- + +.. autosummary:: + :toctree: generated/ + + :template: class.rst + + checkpoint.Process + +.. autosummary:: + :toctree: generated/ + + checkpoint.ProcessState + + CUDA system information and NVIDIA Management Library (NVML) ------------------------------------------------------------ diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst index 34eff571005..13e1430ee23 100644 --- a/cuda_core/docs/source/release/1.0.0-notes.rst +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -16,7 +16,10 @@ Highlights New features ------------ -- TBD +- Added the :mod:`cuda.core.checkpoint` module for CUDA process checkpointing, + including process state queries, lock/checkpoint/restore/unlock operations, + and GPU UUID remapping support for restore. + (`#1343 `__) Fixes and enhancements diff --git a/cuda_core/tests/test_checkpoint.py b/cuda_core/tests/test_checkpoint.py new file mode 100644 index 00000000000..c461fb6e6e0 --- /dev/null +++ b/cuda_core/tests/test_checkpoint.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from enum import IntEnum + +import pytest + +from cuda.core import checkpoint + + +class _DriverProcessState(IntEnum): + CU_PROCESS_STATE_RUNNING = 0 + CU_PROCESS_STATE_LOCKED = 1 + CU_PROCESS_STATE_CHECKPOINTED = 2 + CU_PROCESS_STATE_FAILED = 3 + + +class _DriverResult(IntEnum): + CUDA_SUCCESS = 0 + CUDA_ERROR_NOT_FOUND = 500 + CUDA_ERROR_NOT_SUPPORTED = 801 + + +class _Uuid: + pass + + +class _CheckpointGpuPair: + def __init__(self): + self.oldUuid = None + self.newUuid = None + + +class _CheckpointLockArgs: + def __init__(self): + self.timeoutMs = None + + +class _CheckpointRestoreArgs: + def __init__(self): + self.gpuPairs = None + self.gpuPairsCount = None + + +class _Driver: + CUresult = _DriverResult + CUprocessState = _DriverProcessState + CUcheckpointGpuPair = _CheckpointGpuPair + CUcheckpointLockArgs = _CheckpointLockArgs + CUcheckpointRestoreArgs = _CheckpointRestoreArgs + + def __init__(self): + self.calls = [] + + def cuCheckpointProcessGetState(self, pid): + self.calls.append(("get_state", pid)) + return (0, self.CUprocessState.CU_PROCESS_STATE_CHECKPOINTED) + + def cuCheckpointProcessGetRestoreThreadId(self, pid): + self.calls.append(("get_restore_thread_id", pid)) + return (0, 123) + + def cuCheckpointProcessLock(self, pid, args): + self.calls.append(("lock", pid, args)) + return (0,) + + def cuCheckpointProcessCheckpoint(self, pid, args): + self.calls.append(("checkpoint", pid, args)) + return (0,) + + def cuCheckpointProcessRestore(self, pid, args): + self.calls.append(("restore", pid, args)) + return (0,) + + def cuCheckpointProcessUnlock(self, pid, args): + self.calls.append(("unlock", pid, args)) + return (0,) + + +@pytest.fixture +def checkpoint_driver(monkeypatch): + driver = _Driver() + monkeypatch.setattr(checkpoint, "_get_driver", lambda: driver) + + def handle_return(driver, result): + if len(result) == 1: + return None + return result[1] + + monkeypatch.setattr(checkpoint, "_handle_return", handle_return) + return driver + + +def test_public_checkpoint_symbols(): + assert checkpoint.ProcessState.CHECKPOINTED == 2 + assert "Process" in checkpoint.__all__ + assert "ProcessState" in checkpoint.__all__ + for name in ("Any", "Mapping", "IntEnum", "dataclass", "handle_return"): + assert not hasattr(checkpoint, name) + + +def test_process_state(checkpoint_driver): + state = checkpoint.Process(42).state + + assert state is checkpoint.ProcessState.CHECKPOINTED + assert checkpoint_driver.calls == [("get_state", 42)] + + +def test_process_restore_thread_id(checkpoint_driver): + tid = checkpoint.Process(42).restore_thread_id + + assert tid == 123 + assert checkpoint_driver.calls == [("get_restore_thread_id", 42)] + + +def test_process_lock_sets_timeout_ms(checkpoint_driver): + checkpoint.Process(42).lock(timeout_ms=500) + + opname, pid, args = checkpoint_driver.calls[0] + assert opname == "lock" + assert pid == 42 + assert isinstance(args, _CheckpointLockArgs) + assert args.timeoutMs == 500 + + +def test_process_checkpoint_and_unlock_pass_null_args(checkpoint_driver): + process = checkpoint.Process(42) + process.checkpoint() + process.unlock() + + assert checkpoint_driver.calls == [ + ("checkpoint", 42, None), + ("unlock", 42, None), + ] + + +def test_process_restore_accepts_gpu_uuid_mapping(checkpoint_driver): + old_uuid = _Uuid() + new_uuid = _Uuid() + + checkpoint.Process(42).restore(gpu_mapping={old_uuid: new_uuid}) + + opname, pid, args = checkpoint_driver.calls[0] + assert opname == "restore" + assert pid == 42 + assert isinstance(args, _CheckpointRestoreArgs) + assert args.gpuPairsCount == 1 + assert len(args.gpuPairs) == 1 + assert args.gpuPairs[0].oldUuid is old_uuid + assert args.gpuPairs[0].newUuid is new_uuid + + +def test_process_restore_empty_gpu_mapping_uses_null_args(checkpoint_driver): + checkpoint.Process(42).restore(gpu_mapping={}) + + assert checkpoint_driver.calls == [("restore", 42, None)] + + +@pytest.mark.parametrize( + ("args", "error_type", "match"), + [ + (("123",), TypeError, "pid must be an int"), + ((True,), TypeError, "pid must be an int"), + ((0,), ValueError, "pid must be a positive int"), + ], +) +def test_process_rejects_invalid_pid(checkpoint_driver, args, error_type, match): + with pytest.raises(error_type, match=match): + checkpoint.Process(*args) + + +@pytest.mark.parametrize( + ("timeout_ms", "error_type", "match"), + [ + (-1, ValueError, "timeout_ms must be >= 0"), + (1.5, TypeError, "timeout_ms must be an int"), + (True, TypeError, "timeout_ms must be an int"), + ], +) +def test_process_lock_rejects_invalid_timeout(checkpoint_driver, timeout_ms, error_type, match): + with pytest.raises(error_type, match=match): + checkpoint.Process(42).lock(timeout_ms=timeout_ms) + + +def test_process_restore_rejects_invalid_gpu_mapping(checkpoint_driver): + with pytest.raises(TypeError, match="gpu_mapping must be a mapping"): + checkpoint.Process(42).restore(gpu_mapping=[object()]) + + +@pytest.mark.parametrize( + "error_name", + [ + "CUDA_ERROR_NOT_FOUND", + "CUDA_ERROR_NOT_SUPPORTED", + ], +) +def test_checkpoint_apis_reject_unsupported_driver(error_name): + driver = _Driver() + result = (getattr(driver.CUresult, error_name),) + + with pytest.raises(RuntimeError, match="CUDA checkpointing is not supported"): + checkpoint._handle_return(driver, result) From 4992921dfe046ad1b2beabcece1261e5f6686c0d Mon Sep 17 00:00:00 2001 From: Keith Kraus Date: Wed, 29 Apr 2026 15:40:27 -0400 Subject: [PATCH 2/2] Address checkpoint review feedback --- cuda_core/cuda/core/checkpoint.py | 113 ++++++++++++------ cuda_core/docs/source/api.rst | 35 +++++- cuda_core/docs/source/release/1.0.0-notes.rst | 4 +- cuda_core/tests/test_checkpoint.py | 102 +++++++++++++--- 4 files changed, 197 insertions(+), 57 deletions(-) diff --git a/cuda_core/cuda/core/checkpoint.py b/cuda_core/cuda/core/checkpoint.py index ad0be778974..1333a8f0e43 100644 --- a/cuda_core/cuda/core/checkpoint.py +++ b/cuda_core/cuda/core/checkpoint.py @@ -3,11 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 from collections.abc import Mapping as _Mapping -from dataclasses import dataclass as _dataclass -from enum import IntEnum as _IntEnum from typing import Any as _Any +from typing import Literal as _Literal from cuda.core._utils.cuda_utils import handle_return as _handle_cuda_return +from cuda.core._utils.version import binding_version as _binding_version +from cuda.core._utils.version import driver_version as _driver_version try: from cuda.bindings import driver as _driver @@ -15,18 +16,30 @@ from cuda import cuda as _driver -class ProcessState(_IntEnum): - """ - CUDA checkpoint state for a process. - """ +ProcessStateT = _Literal["running", "locked", "checkpointed", "failed"] + +_PROCESS_STATE_NAMES: dict[int, ProcessStateT] = { + 0: "running", + 1: "locked", + 2: "checkpointed", + 3: "failed", +} - RUNNING = 0 - LOCKED = 1 - CHECKPOINTED = 2 - FAILED = 3 +_REQUIRED_BINDING_ATTRS = ( + "cuCheckpointProcessCheckpoint", + "cuCheckpointProcessGetRestoreThreadId", + "cuCheckpointProcessGetState", + "cuCheckpointProcessLock", + "cuCheckpointProcessRestore", + "cuCheckpointProcessUnlock", + "CUcheckpointGpuPair", + "CUcheckpointLockArgs", + "CUcheckpointRestoreArgs", +) +_REQUIRED_DRIVER_VERSION = (12, 8, 0) +_driver_capability_checked = False -@_dataclass(frozen=True) class Process: """ CUDA process that can be locked, checkpointed, restored, and unlocked. @@ -37,19 +50,23 @@ class Process: Process ID of the CUDA process. """ - pid: int + __slots__ = ("pid",) - def __post_init__(self): - _check_pid(self.pid) + def __init__(self, pid: int): + self.pid = _check_pid(pid) @property - def state(self) -> ProcessState: + def state(self) -> ProcessStateT: """ CUDA checkpoint state for this process. """ driver = _get_driver() - state = _handle_return(driver, driver.cuCheckpointProcessGetState(self.pid)) - return ProcessState(int(state)) + state = _call_driver(driver, driver.cuCheckpointProcessGetState, self.pid) + state_value = int(state) + try: + return _PROCESS_STATE_NAMES[state_value] + except KeyError as e: + raise RuntimeError(f"Unknown CUDA checkpoint process state: {state_value}") from e @property def restore_thread_id(self) -> int: @@ -57,7 +74,7 @@ def restore_thread_id(self) -> int: CUDA restore thread ID for this process. """ driver = _get_driver() - return _handle_return(driver, driver.cuCheckpointProcessGetRestoreThreadId(self.pid)) + return _call_driver(driver, driver.cuCheckpointProcessGetRestoreThreadId, self.pid) def lock(self, timeout_ms: int = 0) -> None: """ @@ -71,14 +88,14 @@ def lock(self, timeout_ms: int = 0) -> None: driver = _get_driver() args = driver.CUcheckpointLockArgs() args.timeoutMs = _check_timeout_ms(timeout_ms) - _handle_return(driver, driver.cuCheckpointProcessLock(self.pid, args)) + _call_driver(driver, driver.cuCheckpointProcessLock, self.pid, args) def checkpoint(self) -> None: """ Checkpoint the GPU memory contents of this locked process. """ driver = _get_driver() - _handle_return(driver, driver.cuCheckpointProcessCheckpoint(self.pid, None)) + _call_driver(driver, driver.cuCheckpointProcessCheckpoint, self.pid, None) def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None: """ @@ -93,36 +110,63 @@ def restore(self, gpu_mapping: _Mapping[_Any, _Any] | None = None) -> None: """ driver = _get_driver() args = _make_restore_args(driver, gpu_mapping) - _handle_return(driver, driver.cuCheckpointProcessRestore(self.pid, args)) + _call_driver(driver, driver.cuCheckpointProcessRestore, self.pid, args) def unlock(self) -> None: """ Unlock this locked process so it can resume CUDA API calls. """ driver = _get_driver() - _handle_return(driver, driver.cuCheckpointProcessUnlock(self.pid, None)) + _call_driver(driver, driver.cuCheckpointProcessUnlock, self.pid, None) def _get_driver(): - required = ( - "cuCheckpointProcessCheckpoint", - "cuCheckpointProcessGetRestoreThreadId", - "cuCheckpointProcessGetState", - "cuCheckpointProcessLock", - "cuCheckpointProcessRestore", - "cuCheckpointProcessUnlock", - "CUcheckpointGpuPair", - "CUcheckpointLockArgs", - "CUcheckpointRestoreArgs", - ) - missing = [name for name in required if not hasattr(_driver, name)] + global _driver_capability_checked + if _driver_capability_checked: + return _driver + + binding_ver = _binding_version() + if not _binding_version_supports_checkpoint(binding_ver): + raise RuntimeError( + "CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. " + f"Found cuda.bindings {'.'.join(str(part) for part in binding_ver[:3])}." + ) + + missing = [name for name in _REQUIRED_BINDING_ATTRS if not hasattr(_driver, name)] if missing: raise RuntimeError( f"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Missing: {', '.join(missing)}" ) + + driver_ver = _driver_version() + if driver_ver < _REQUIRED_DRIVER_VERSION: + raise RuntimeError( + "CUDA checkpointing is not supported by the installed NVIDIA driver. " + "Upgrade to a driver version with CUDA checkpoint API support." + ) + + _driver_capability_checked = True return _driver +def _binding_version_supports_checkpoint(version) -> bool: + major, minor, patch = version[:3] + return (major == 12 and (minor, patch) >= (8, 0)) or (major == 13 and (minor, patch) >= (0, 2)) or major > 13 + + +def _call_driver(driver, func, *args): + try: + result = func(*args) + except RuntimeError as e: + if "cuCheckpointProcess" in str(e) and "not found" in str(e): + raise RuntimeError( + "CUDA checkpointing is not supported by the installed NVIDIA driver. " + "Upgrade to a driver version with CUDA checkpoint API support." + ) from e + raise + return _handle_return(driver, result) + + def _handle_return(driver, result): err = result[0] not_supported_errors = ( @@ -178,5 +222,4 @@ def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None): __all__ = [ "Process", - "ProcessState", ] diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst index 5d7efdb2d17..2762f8ca541 100644 --- a/cuda_core/docs/source/api.rst +++ b/cuda_core/docs/source/api.rst @@ -177,17 +177,42 @@ CUDA compilation toolchain CUDA process checkpointing -------------------------- -.. autosummary:: - :toctree: generated/ +The :mod:`cuda.core.checkpoint` module wraps the CUDA driver process +checkpoint APIs. These APIs are intended for Linux process checkpoint and +restore workflows, and require a CUDA driver with checkpoint API support and +a ``cuda-bindings`` version that exposes those driver entry points. - :template: class.rst +A checkpoint workflow operates on a CUDA process by process ID. The typical +sequence is to lock the process, capture its GPU memory state, restore it +when needed, and then unlock it so CUDA API calls can resume: - checkpoint.Process +.. code-block:: python + + from cuda.core import checkpoint + + process = checkpoint.Process(pid) + process.lock(timeout_ms=5000) + process.checkpoint() + process.restore() + process.unlock() + +``Process.state`` returns one of ``"running"``, ``"locked"``, +``"checkpointed"``, or ``"failed"``. Restore may optionally remap GPUs by +passing ``gpu_mapping`` from each checkpointed GPU UUID to the GPU UUID that +should be used during restore. A successful restore returns the process to +the locked state; call ``Process.unlock`` after restore to allow CUDA API +calls to resume. + +The CUDA driver requires restore to run from the process restore thread. +Use ``Process.restore_thread_id`` to discover that thread before calling +``Process.restore`` from a checkpoint coordinator. .. autosummary:: :toctree: generated/ - checkpoint.ProcessState + :template: class.rst + + checkpoint.Process CUDA system information and NVIDIA Management Library (NVML) diff --git a/cuda_core/docs/source/release/1.0.0-notes.rst b/cuda_core/docs/source/release/1.0.0-notes.rst index 13e1430ee23..f5d3645c3d6 100644 --- a/cuda_core/docs/source/release/1.0.0-notes.rst +++ b/cuda_core/docs/source/release/1.0.0-notes.rst @@ -17,8 +17,8 @@ New features ------------ - Added the :mod:`cuda.core.checkpoint` module for CUDA process checkpointing, - including process state queries, lock/checkpoint/restore/unlock operations, - and GPU UUID remapping support for restore. + including string process state queries, lock/checkpoint/restore/unlock + operations, and GPU UUID remapping support for restore. (`#1343 `__) diff --git a/cuda_core/tests/test_checkpoint.py b/cuda_core/tests/test_checkpoint.py index c461fb6e6e0..d92a0c632ab 100644 --- a/cuda_core/tests/test_checkpoint.py +++ b/cuda_core/tests/test_checkpoint.py @@ -9,14 +9,14 @@ from cuda.core import checkpoint -class _DriverProcessState(IntEnum): +class _MockDriverProcessState(IntEnum): CU_PROCESS_STATE_RUNNING = 0 CU_PROCESS_STATE_LOCKED = 1 CU_PROCESS_STATE_CHECKPOINTED = 2 CU_PROCESS_STATE_FAILED = 3 -class _DriverResult(IntEnum): +class _MockDriverResult(IntEnum): CUDA_SUCCESS = 0 CUDA_ERROR_NOT_FOUND = 500 CUDA_ERROR_NOT_SUPPORTED = 801 @@ -43,19 +43,20 @@ def __init__(self): self.gpuPairsCount = None -class _Driver: - CUresult = _DriverResult - CUprocessState = _DriverProcessState +class _MockDriver: + CUresult = _MockDriverResult + CUprocessState = _MockDriverProcessState CUcheckpointGpuPair = _CheckpointGpuPair CUcheckpointLockArgs = _CheckpointLockArgs CUcheckpointRestoreArgs = _CheckpointRestoreArgs - def __init__(self): + def __init__(self, process_state=_MockDriverProcessState.CU_PROCESS_STATE_CHECKPOINTED): self.calls = [] + self.process_state = process_state def cuCheckpointProcessGetState(self, pid): self.calls.append(("get_state", pid)) - return (0, self.CUprocessState.CU_PROCESS_STATE_CHECKPOINTED) + return (0, self.process_state) def cuCheckpointProcessGetRestoreThreadId(self, pid): self.calls.append(("get_restore_thread_id", pid)) @@ -80,7 +81,7 @@ def cuCheckpointProcessUnlock(self, pid, args): @pytest.fixture def checkpoint_driver(monkeypatch): - driver = _Driver() + driver = _MockDriver() monkeypatch.setattr(checkpoint, "_get_driver", lambda: driver) def handle_return(driver, result): @@ -93,17 +94,27 @@ def handle_return(driver, result): def test_public_checkpoint_symbols(): - assert checkpoint.ProcessState.CHECKPOINTED == 2 - assert "Process" in checkpoint.__all__ - assert "ProcessState" in checkpoint.__all__ - for name in ("Any", "Mapping", "IntEnum", "dataclass", "handle_return"): + assert set(checkpoint.ProcessStateT.__args__) == {"running", "locked", "checkpointed", "failed"} + assert checkpoint.__all__ == ["Process"] + for name in ("Any", "Mapping", "Literal", "IntEnum", "dataclass", "handle_return", "ProcessState"): assert not hasattr(checkpoint, name) -def test_process_state(checkpoint_driver): +@pytest.mark.parametrize( + ("process_state", "expected"), + [ + (_MockDriverProcessState.CU_PROCESS_STATE_RUNNING, "running"), + (_MockDriverProcessState.CU_PROCESS_STATE_LOCKED, "locked"), + (_MockDriverProcessState.CU_PROCESS_STATE_CHECKPOINTED, "checkpointed"), + (_MockDriverProcessState.CU_PROCESS_STATE_FAILED, "failed"), + ], +) +def test_process_state(checkpoint_driver, process_state, expected): + checkpoint_driver.process_state = process_state + state = checkpoint.Process(42).state - assert state is checkpoint.ProcessState.CHECKPOINTED + assert state == expected assert checkpoint_driver.calls == [("get_state", 42)] @@ -196,8 +207,69 @@ def test_process_restore_rejects_invalid_gpu_mapping(checkpoint_driver): ], ) def test_checkpoint_apis_reject_unsupported_driver(error_name): - driver = _Driver() + driver = _MockDriver() result = (getattr(driver.CUresult, error_name),) with pytest.raises(RuntimeError, match="CUDA checkpointing is not supported"): checkpoint._handle_return(driver, result) + + +def test_get_driver_caches_capability_check(monkeypatch): + calls = {"binding_version": 0, "driver_version": 0} + + def binding_version(): + calls["binding_version"] += 1 + return (13, 0, 2) + + def driver_version(): + calls["driver_version"] += 1 + return (12, 8, 0) + + driver = _MockDriver() + monkeypatch.setattr(checkpoint, "_driver", driver) + monkeypatch.setattr(checkpoint, "_driver_capability_checked", False) + monkeypatch.setattr(checkpoint, "_binding_version", binding_version) + monkeypatch.setattr(checkpoint, "_driver_version", driver_version) + + assert checkpoint._get_driver() is driver + assert checkpoint._get_driver() is driver + assert calls == {"binding_version": 1, "driver_version": 1} + + +@pytest.mark.parametrize("binding_version", [(12, 7, 0), (13, 0, 1)]) +def test_get_driver_rejects_unsupported_binding_version(monkeypatch, binding_version): + monkeypatch.setattr(checkpoint, "_driver", _MockDriver()) + monkeypatch.setattr(checkpoint, "_driver_capability_checked", False) + monkeypatch.setattr(checkpoint, "_binding_version", lambda: binding_version) + + with pytest.raises(RuntimeError, match="CUDA checkpointing requires cuda.bindings"): + checkpoint._get_driver() + + +def test_get_driver_rejects_missing_binding_symbols(monkeypatch): + monkeypatch.setattr(checkpoint, "_driver", object()) + monkeypatch.setattr(checkpoint, "_driver_capability_checked", False) + monkeypatch.setattr(checkpoint, "_binding_version", lambda: (13, 0, 2)) + + with pytest.raises(RuntimeError, match="Missing: cuCheckpointProcessCheckpoint"): + checkpoint._get_driver() + + +def test_get_driver_rejects_unsupported_driver_version(monkeypatch): + monkeypatch.setattr(checkpoint, "_driver", _MockDriver()) + monkeypatch.setattr(checkpoint, "_driver_capability_checked", False) + monkeypatch.setattr(checkpoint, "_binding_version", lambda: (13, 0, 2)) + monkeypatch.setattr(checkpoint, "_driver_version", lambda: (12, 7, 0)) + + with pytest.raises(RuntimeError, match="CUDA checkpointing is not supported"): + checkpoint._get_driver() + + +def test_checkpoint_apis_translate_missing_runtime_symbol(): + driver = _MockDriver() + + def missing_checkpoint_symbol(): + raise RuntimeError('Function "cuCheckpointProcessLock" not found') + + with pytest.raises(RuntimeError, match="CUDA checkpointing is not supported"): + checkpoint._call_driver(driver, missing_checkpoint_symbol)