def _reduce_event(event):
    """Reducer registered with ``multiprocessing.reduction`` for ``Event``.

    Runs the one-time multiprocessing start-method check, then returns the
    ``(callable, args)`` pair used to rebuild the event: the event's
    ``from_ipc_descriptor`` attribute and a one-element tuple holding the
    result of ``event.get_ipc_descriptor()``.
    """
    check_multiprocessing_start_method()
    rebuild = event.from_ipc_descriptor
    descriptor = event.get_ipc_descriptor()
    return rebuild, (descriptor,)
# Set once the start-method check has run, so the warning is issued at most
# once per process (until reset_fork_warning() clears the flag).
_fork_warning_checked = False


def reset_fork_warning():
    """Reset the fork warning check flag for testing purposes.

    This function is intended for use in tests to allow multiple test runs
    to check the warning behavior.
    """
    global _fork_warning_checked
    _fork_warning_checked = False


def check_multiprocessing_start_method():
    """Warn once if the multiprocessing start method is, or defaults to, 'fork'.

    The check runs only on the first call; later calls return immediately
    until ``reset_fork_warning()`` is called. A ``UserWarning`` is emitted when

    * the start method has been explicitly fixed to ``'fork'``, or
    * no start method has been fixed yet and the platform default is
      ``'fork'`` (Linux on Python versions before 3.14).

    The start method is queried with ``allow_none=True`` so that this check
    never fixes the default context itself; the user can still follow the
    advice in the warning and call ``multiprocessing.set_start_method('spawn')``
    afterwards.
    """
    global _fork_warning_checked
    if _fork_warning_checked:
        return
    _fork_warning_checked = True

    # Local import: only needed for the interpreter version test below.
    import sys

    # Common warning message parts
    common_message = (
        "CUDA does not support. Forked subprocesses exhibit undefined behavior, "
        "including failure to initialize CUDA contexts and devices. Set the start method "
        "to 'spawn' before creating processes that use CUDA. "
        "Use: multiprocessing.set_start_method('spawn')"
    )

    # get_start_method() without allow_none never raises for an unset start
    # method; it silently fixes the default context instead, which would make
    # a later set_start_method('spawn') fail. With allow_none=True an unset
    # start method is reported as None and nothing is fixed.
    start_method = multiprocessing.get_start_method(allow_none=True)

    if start_method == "fork":
        message = f"multiprocessing start method is 'fork', which {common_message}"
    elif (
        start_method is None
        and platform.system() == "Linux"
        # Since Python 3.14 the POSIX default is 'forkserver', not 'fork'.
        and sys.version_info < (3, 14)
    ):
        message = (
            f"multiprocessing start method is not set and defaults to 'fork' on Linux, "
            f"which {common_message}"
        )
    else:
        return

    # stacklevel=3 attributes the warning to the caller of the reduction
    # function that invoked this check.
    warnings.warn(message, UserWarning, stacklevel=3)
def test_warn_on_fork_method_device_memory_resource(ipc_device):
    """Reducing a DeviceMemoryResource under a mocked 'fork' start method warns."""
    ipc_device.set_current()
    mr_options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(ipc_device, options=mr_options)

    fork_patch = patch("multiprocessing.get_start_method", return_value="fork")
    with fork_patch, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the warn-once flag so this test observes the warning
        reset_fork_warning()

        # Call the reducer directly instead of going through pickle
        _deep_reduce_device_memory_resource(mr)

        # Exactly one UserWarning mentioning fork, spawn and undefined behavior
        assert len(caught) == 1, f"Expected 1 warning, got {len(caught)}: {[str(item.message) for item in caught]}"
        record = caught[0]
        assert record.category is UserWarning
        text = str(record.message).lower()
        assert "fork" in text
        assert "spawn" in text
        assert "undefined behavior" in text

    mr.close()


def test_warn_on_fork_method_allocation_handle(ipc_device):
    """Reducing an IPCAllocationHandle under a mocked 'fork' start method warns."""
    ipc_device.set_current()
    mr_options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(ipc_device, options=mr_options)
    handle = mr.get_allocation_handle()

    fork_patch = patch("multiprocessing.get_start_method", return_value="fork")
    with fork_patch, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the warn-once flag so this test observes the warning
        reset_fork_warning()

        # Call the reducer directly instead of going through pickle
        _reduce_allocation_handle(handle)

        # Exactly one UserWarning mentioning fork
        assert len(caught) == 1
        record = caught[0]
        assert record.category is UserWarning
        assert "fork" in str(record.message).lower()

    mr.close()


def test_warn_on_fork_method_event(mempool_device):
    """Reducing an Event under a mocked 'fork' start method warns."""
    mempool_device.set_current()
    stream = mempool_device.create_stream()
    event = stream.record(options=EventOptions(ipc_enabled=True))

    fork_patch = patch("multiprocessing.get_start_method", return_value="fork")
    with fork_patch, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the warn-once flag so this test observes the warning
        reset_fork_warning()

        # Call the reducer directly instead of going through pickle
        _reduce_event(event)

        # Exactly one UserWarning mentioning fork
        assert len(caught) == 1
        record = caught[0]
        assert record.category is UserWarning
        assert "fork" in str(record.message).lower()

    event.close()


def test_no_warning_with_spawn_method(ipc_device):
    """No fork-related warning is emitted when the start method is 'spawn'."""
    ipc_device.set_current()
    mr_options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    mr = DeviceMemoryResource(ipc_device, options=mr_options)

    spawn_patch = patch("multiprocessing.get_start_method", return_value="spawn")
    with spawn_patch, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the warn-once flag so the check actually runs
        reset_fork_warning()

        # Call the reducer directly instead of going through pickle
        _deep_reduce_device_memory_resource(mr)

        # Only fork-related warnings matter here; ignore unrelated ones
        fork_warnings = [item for item in caught if "fork" in str(item.message).lower()]
        assert not fork_warnings, f"Unexpected warning: {fork_warnings[0].message if fork_warnings else None}"

    mr.close()


def test_warning_emitted_only_once(ipc_device):
    """Reducing several objects in a row produces a single fork warning."""
    ipc_device.set_current()
    mr_options = DeviceMemoryResourceOptions(max_size=2097152, ipc_enabled=True)
    first_mr = DeviceMemoryResource(ipc_device, options=mr_options)
    second_mr = DeviceMemoryResource(ipc_device, options=mr_options)

    fork_patch = patch("multiprocessing.get_start_method", return_value="fork")
    with fork_patch, warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")

        # Clear the warn-once flag a single time, before both reductions
        reset_fork_warning()

        # Two reductions back to back
        for resource in (first_mr, second_mr):
            _deep_reduce_device_memory_resource(resource)

        # Only the first reduction may have warned
        fork_warnings = [item for item in caught if "fork" in str(item.message).lower()]
        assert len(fork_warnings) == 1, f"Expected 1 warning, got {len(fork_warnings)}"

    first_mr.close()
    second_mr.close()