Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions cuda_bindings/tests/test_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,6 +861,27 @@ def test_cuCheckpointProcessGetState_failure():
assert state is None


@pytest.mark.skipif(not supportsCudaAPI("cuCheckpointProcessGetState"), reason="When API was introduced")
def test_cuCheckpoint_required_bindings_present():
required_bindings = (
"cuCheckpointProcessCheckpoint",
"cuCheckpointProcessGetRestoreThreadId",
"cuCheckpointProcessGetState",
"cuCheckpointProcessLock",
"cuCheckpointProcessRestore",
"cuCheckpointProcessUnlock",
"CUcheckpointLockArgs",
"CUprocessState",
"CUcheckpointRestoreArgs",
)
if cuda.CUDA_VERSION >= 13000:
required_bindings += ("CUcheckpointGpuPair",)

missing = [name for name in required_bindings if not hasattr(cuda, name)]

assert missing == []


def test_private_function_pointer_inspector():
from cuda.bindings._bindings.cydriver import _inspect_function_pointer

Expand Down
18 changes: 14 additions & 4 deletions cuda_core/cuda/core/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@
"cuCheckpointProcessLock",
"cuCheckpointProcessRestore",
"cuCheckpointProcessUnlock",
"CUcheckpointGpuPair",
"CUcheckpointLockArgs",
"CUprocessState",
"CUcheckpointRestoreArgs",
)
_GPU_MAPPING_BINDING_ATTRS = ("CUcheckpointGpuPair",)
_REQUIRED_DRIVER_VERSION = (12, 8, 0)
_driver_capability_checked = False

Expand Down Expand Up @@ -215,23 +215,33 @@ def _make_restore_args(driver, gpu_mapping: _Mapping[_Any, _Any] | None):
if not isinstance(gpu_mapping, _Mapping):
raise TypeError("gpu_mapping must be a mapping from checkpointed GPU UUID to restore GPU UUID")

if not gpu_mapping:
return None

pairs = []
_require_gpu_mapping_bindings(driver)
for old_uuid, new_uuid in gpu_mapping.items():
pair = driver.CUcheckpointGpuPair()
buffers = []
pair.oldUuid = _as_cuuuid(driver, old_uuid, buffers)
pair.newUuid = _as_cuuuid(driver, new_uuid, buffers)
pairs.append(pair)

if not pairs:
return None

args = driver.CUcheckpointRestoreArgs()
args.gpuPairs = pairs
args.gpuPairsCount = len(pairs)
return args


def _require_gpu_mapping_bindings(driver) -> None:
missing = [name for name in _GPU_MAPPING_BINDING_ATTRS if not hasattr(driver, name)]
if missing:
raise RuntimeError(
"CUDA checkpoint GPU remapping requires cuda.bindings with GPU remapping support. "
f"Missing: {', '.join(missing)}"
)


def _as_cuuuid(driver, value, buffers):
"""Convert *value* to a ``CUuuid``.

Expand Down
77 changes: 75 additions & 2 deletions cuda_core/tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,48 @@ def _checkpoint_available():
try:
checkpoint._get_driver()
return True
except RuntimeError:
except RuntimeError as exc:
if _checkpoint_unavailable_can_skip(str(exc)):
return False
raise


def _checkpoint_gpu_mapping_available():
"""Return True if checkpoint restore GPU remapping is usable on this system."""
if not _checkpoint_available():
return False
try:
checkpoint._require_gpu_mapping_bindings(checkpoint._get_driver())
return True
except RuntimeError as exc:
if _checkpoint_gpu_mapping_unavailable_can_skip(str(exc)):
return False
raise


def _checkpoint_unavailable_can_skip(message):
return message.startswith(
(
"CUDA checkpointing is not supported by the installed NVIDIA driver.",
"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. Found cuda.bindings ",
)
)


def _checkpoint_gpu_mapping_unavailable_can_skip(message):
return message.startswith(
"CUDA checkpoint GPU remapping requires cuda.bindings with GPU remapping support. Missing: CUcheckpointGpuPair"
)


needs_checkpoint = pytest.mark.skipif(
sys.platform != "linux" or not _checkpoint_available(),
reason="CUDA checkpoint API requires Linux and a supported driver/bindings",
)
needs_checkpoint_gpu_mapping = pytest.mark.skipif(
sys.platform != "linux" or not _checkpoint_gpu_mapping_available(),
reason="CUDA checkpoint GPU remapping requires Linux and supported driver/bindings",
)


# -- Helpers ---------------------------------------------------------------
Expand Down Expand Up @@ -384,6 +418,45 @@ def test_public_symbols(self):
assert checkpoint.__all__ == ["Process"]
assert not hasattr(checkpoint, "ProcessStateType")

def test_checkpoint_available_skips_unsupported_driver(self, monkeypatch):
def raise_unsupported_driver():
raise RuntimeError("CUDA checkpointing is not supported by the installed NVIDIA driver.")

monkeypatch.setattr(checkpoint, "_get_driver", raise_unsupported_driver)

assert not _checkpoint_available()

def test_checkpoint_available_skips_old_bindings(self, monkeypatch):
def raise_old_bindings():
raise RuntimeError(
"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. "
"Found cuda.bindings 12.7.0."
)

monkeypatch.setattr(checkpoint, "_get_driver", raise_old_bindings)

assert not _checkpoint_available()

def test_checkpoint_available_fails_missing_required_bindings(self, monkeypatch):
def raise_missing_binding():
raise RuntimeError(
"CUDA checkpointing requires cuda.bindings with CUDA checkpoint API support. "
"Missing: CUcheckpointRestoreArgs"
)

monkeypatch.setattr(checkpoint, "_get_driver", raise_missing_binding)

with pytest.raises(RuntimeError, match="Missing: CUcheckpointRestoreArgs"):
_checkpoint_available()

def test_checkpoint_gpu_mapping_available_skips_missing_gpu_pair(self, monkeypatch):
class Driver:
pass

monkeypatch.setattr(checkpoint, "_get_driver", lambda: Driver)

assert not _checkpoint_gpu_mapping_available()

def test_pid_is_read_only(self):
proc = checkpoint.Process(1)
assert proc.pid == 1
Expand Down Expand Up @@ -420,7 +493,7 @@ def test_full_cycle_no_migration(self):
# -- GPU migration (>= 2 same-chip GPUs, real driver) ---------------------


@needs_checkpoint
@needs_checkpoint_gpu_mapping
class TestCheckpointGpuMigration:
"""GPU UUID remapping tests following the r580-migration-api.c pattern.

Expand Down
Loading