Skip to content

Commit

Permalink
Fix error using "set_start_method()" after "logger" import (#974)
Browse files Browse the repository at this point in the history
Calling "multiprocessing.get_context(method=None)" had the unexpected
side effect of also fixing the global start method (which can't be
changed afterwards).
  • Loading branch information
Delgan committed Sep 11, 2023
1 parent 14fa062 commit 086126f
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 27 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
=============

- Add support for formatting of ``ExceptionGroup`` errors (`#805 <https://github.com/Delgan/loguru/issues/805>`_).
- Fix possible ``RuntimeError`` when using ``multiprocessing.set_start_method()`` after importing the ``logger`` (`#974 <https://github.com/Delgan/loguru/issues/974>`_)
- Fix formatting of possible ``__notes__`` attached to an ``Exception`` (`#980 <https://github.com/Delgan/loguru/issues/980>`_).


Expand Down
12 changes: 9 additions & 3 deletions loguru/_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import json
import multiprocessing
import os
import threading
from contextlib import contextmanager
Expand Down Expand Up @@ -88,10 +89,15 @@ def __init__(
self._decolorized_format = self._formatter.strip()

if self._enqueue:
self._queue = self._multiprocessing_context.SimpleQueue()
if self._multiprocessing_context is None:
self._queue = multiprocessing.SimpleQueue()
self._confirmation_event = multiprocessing.Event()
self._confirmation_lock = multiprocessing.Lock()
else:
self._queue = self._multiprocessing_context.SimpleQueue()
self._confirmation_event = self._multiprocessing_context.Event()
self._confirmation_lock = self._multiprocessing_context.Lock()
self._queue_lock = create_handler_lock()
self._confirmation_event = self._multiprocessing_context.Event()
self._confirmation_lock = self._multiprocessing_context.Lock()
self._owner_process_pid = os.getpid()
self._thread = Thread(
target=self._queued_writer, daemon=True, name="loguru-writer-%d" % self._id
Expand Down
4 changes: 2 additions & 2 deletions loguru/_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,9 +967,9 @@ def add(
if not isinstance(encoding, str):
encoding = "ascii"

if context is None or isinstance(context, str):
if isinstance(context, str):
context = get_context(context)
elif not isinstance(context, BaseContext):
elif context is not None and not isinstance(context, BaseContext):
raise TypeError(
"Invalid context, it should be a string or a multiprocessing context, "
"not: '%s'" % type(context).__name__
Expand Down
52 changes: 30 additions & 22 deletions tests/test_add_option_context.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,63 @@
import multiprocessing
import os
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from loguru import logger


def get_handler_context():
# No better way to test correct value than to access the private attribute.
handler = next(iter(logger._core.handlers.values()))
return handler._multiprocessing_context
@pytest.fixture
def reset_start_method():
yield
multiprocessing.set_start_method(None, force=True)


def test_default_context():
logger.add(lambda _: None, context=None)
assert get_handler_context() == multiprocessing.get_context(None)
@pytest.mark.usefixtures("reset_start_method")
def test_using_multiprocessing_directly_if_context_is_none():
logger.add(lambda _: None, enqueue=True, context=None)
assert multiprocessing.get_start_method(allow_none=True) is not None


@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
@pytest.mark.parametrize("context_name", ["fork", "forkserver"])
def test_fork_context_as_string(context_name):
logger.add(lambda _: None, context=context_name)
assert get_handler_context() == multiprocessing.get_context(context_name)
context = multiprocessing.get_context(context_name)
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
logger.add(lambda _: None, context=context_name, enqueue=True)
assert mock.called
assert multiprocessing.get_start_method(allow_none=True) is None


def test_spawn_context_as_string():
logger.add(lambda _: None, context="spawn")
assert get_handler_context() == multiprocessing.get_context("spawn")
context = multiprocessing.get_context("spawn")
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
logger.add(lambda _: None, context="spawn", enqueue=True)
assert mock.called
assert multiprocessing.get_start_method(allow_none=True) is None


@pytest.mark.skipif(os.name == "nt", reason="Windows does not support forking")
@pytest.mark.parametrize("context_name", ["fork", "forkserver"])
def test_fork_context_as_object(context_name):
context = multiprocessing.get_context(context_name)
logger.add(lambda _: None, context=context)
assert get_handler_context() == context
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
logger.add(lambda _: None, context=context, enqueue=True)
assert mock.called
assert multiprocessing.get_start_method(allow_none=True) is None


def test_spawn_context_as_object():
context = multiprocessing.get_context("spawn")
logger.add(lambda _: None, context=context)
assert get_handler_context() == context
with patch.object(type(context), "Lock", wraps=context.Lock) as mock:
logger.add(lambda _: None, context=context, enqueue=True)
assert mock.called
assert multiprocessing.get_start_method(allow_none=True) is None


def test_context_effectively_used():
default_context = multiprocessing.get_context()
mocked_context = MagicMock(spec=default_context, wraps=default_context)
logger.add(lambda _: None, context=mocked_context, enqueue=True)
logger.complete()
assert mocked_context.Lock.called
def test_global_start_method_is_none_if_enqueue_is_false():
logger.add(lambda _: None, enqueue=False, context=None)
assert multiprocessing.get_start_method(allow_none=True) is None


def test_invalid_context_name():
Expand Down

0 comments on commit 086126f

Please sign in to comment.