Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion docs/howto.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,27 @@ with add_global_context({"user_id": user_id, "request_id": request_id}):
```


## Using the automatic init

If you don't want to customise the initialization you can let `add_global_context` automatically handle the init and shutdown for you:

```python
import logging

from logging_with_context.global_context import add_global_context


def main():
    logging.basicConfig(level=logging.INFO)  # Or any other way to setup logging.
    with add_global_context({"user_id": 10}):
        # Here the context is automatically initialized.
        # It'll also be automatically shutdown once this context manager finishes.
        logging.getLogger(__name__).info("Doing some work.")
```


## Using the init/shutdown API

In case you can't use the context manager, you can use the manual initialization and shutdown API:
In case you want to customise the initialization but can't use the context manager, you can use the manual initialization and shutdown API:

```python
import logging
Expand Down
2 changes: 1 addition & 1 deletion run_tests.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env bash

for python in 3.9 3.10 3.11 3.12 3.13; do
uv run --locked --isolated --python=$python pytest
uv run --frozen --isolated --python=$python pytest
done
33 changes: 28 additions & 5 deletions src/logging_with_context/global_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
__global_context_var: ContextVar[dict[str, Any]] = ContextVar(
"global_context", default={}
)
__global_context_initialized: ContextVar[bool] = ContextVar(
"global_context_initialized", default=False
)


def _get_loggers_to_process(loggers: Optional[Sequence[Logger]] = None) -> list[Logger]:
    """Normalize *loggers* into a concrete list, defaulting to the root logger."""
    if loggers is None:
        return [getLogger()]
    return list(loggers)


def init_global_context(loggers: Optional[Sequence[Logger]] = None) -> None:
Expand All @@ -26,9 +33,9 @@ def init_global_context(loggers: Optional[Sequence[Logger]] = None) -> None:
loggers: The loggers to attach the global context; if no loggers are specified
it will use the root logger.
"""
loggers_to_process = [getLogger()] if loggers is None else list(loggers)
__global_context_initialized.set(True)
filter_with_context = FilterWithContextVar(__global_context_var)
for logger in loggers_to_process:
for logger in _get_loggers_to_process(loggers):
for handler in logger.handlers:
handler.addFilter(filter_with_context)

Expand All @@ -41,8 +48,8 @@ def shutdown_global_context(loggers: Optional[Sequence[Logger]] = None) -> None:
loggers: The loggers that were used when calling `init_global_context`; by
default the root logger.
"""
loggers_to_process = [getLogger()] if loggers is None else list(loggers)
for logger in loggers_to_process:
__global_context_initialized.set(False)
for logger in _get_loggers_to_process(loggers):
for handler in logger.handlers:
for filter_ in handler.filters:
if not isinstance(filter_, FilterWithContextVar):
Expand Down Expand Up @@ -72,20 +79,36 @@ def global_context_initialized(


@contextmanager
def add_global_context(
    context: dict[str, Any], *, auto_init: bool = True
) -> Generator[None, None, None]:
    """
    Add values to the global context to be attached to all the log messages.

    The values will be removed from the global context once the context manager exits.

    Parameters:
        context: A key/value mapping with the values to add to the global context.
        auto_init: Indicate if the global context should be automatically initialized
            if it isn't.

            If `True`, the context will be also automatically shutdown before exiting.

            If the global context is already initialized it'll do nothing.

            Keyword-only argument.

    Returns:
        A context manager that manages the life of the values.
    """
    auto_initialized = False
    if not __global_context_initialized.get() and auto_init:
        init_global_context()
        auto_initialized = True
    # Merge on top of the current mapping; the token lets reset() restore the
    # exact previous mapping, so nested usages unwind correctly.
    token = __global_context_var.set(__global_context_var.get() | context)
    try:
        yield
    finally:
        __global_context_var.reset(token)
        if auto_initialized:
            # Only tear down what this context manager set up itself.
            shutdown_global_context()
50 changes: 48 additions & 2 deletions tests/logging_with_context/test_global_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import threading

import pytest

Expand All @@ -23,7 +24,7 @@ def test_add_global_context_ok(caplog: pytest.LogCaptureFixture):

def test_add_global_context_without_init_ignored_ok(caplog: pytest.LogCaptureFixture):
logger = logging.getLogger(__name__)
with add_global_context({"key": "value"}):
with add_global_context({"key": "value"}, auto_init=False):
with caplog.at_level(logging.INFO):
logger.info("Test message")
assert len(caplog.records) == 1
Expand All @@ -35,8 +36,53 @@ def test_add_global_context_after_shutdown_ignored_ok(caplog: pytest.LogCaptureF
logger = logging.getLogger(__name__)
with global_context_initialized():
pass
with add_global_context({"key": "value"}), caplog.at_level(logging.INFO):
with (
add_global_context({"key": "value"}, auto_init=False),
caplog.at_level(logging.INFO),
):
logger.info("Test message")
assert len(caplog.records) == 1
result = caplog.records[0]
assert not hasattr(result, "key")


def test_add_global_context_auto_init_ok(caplog: pytest.LogCaptureFixture):
    """With auto_init (the default), context values attach without an explicit init."""
    logger = logging.getLogger(__name__)
    with add_global_context({"key": "value"}):
        with caplog.at_level(logging.INFO):
            logger.info("Test message")
    assert len(caplog.records) == 1
    record = caplog.records[0]
    assert record.key == "value"  # type: ignore


def test_add_global_context_multithread(caplog: pytest.LogCaptureFixture):
    """Nested contexts in concurrent threads must not leak into each other."""

    def worker(value: int) -> None:
        # Three nesting levels per thread; each level multiplies the value by 10
        # so every emitted record carries a distinct, sortable marker.
        with add_global_context({"value": value}):
            logger.info("Message 1 from thread %s", value)
            with add_global_context({"value": value * 10}):
                logger.info("Message 2 from thread %s", value)
                with add_global_context({"value": value * 100}):
                    logger.info("Message 3 from thread %s", value)

    logger = logging.getLogger(__name__)
    with global_context_initialized(), caplog.at_level(logging.INFO):
        threads = [threading.Thread(target=worker, args=(n,)) for n in (1, 2)]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    assert len(caplog.records) == 6
    result = [
        {"message": record.message, "value": record.value}  # type: ignore
        for record in sorted(caplog.records, key=lambda r: r.value)  # type: ignore
    ]
    expected = [
        {"message": "Message 1 from thread 1", "value": 1},
        {"message": "Message 1 from thread 2", "value": 2},
        {"message": "Message 2 from thread 1", "value": 10},
        {"message": "Message 2 from thread 2", "value": 20},
        {"message": "Message 3 from thread 1", "value": 100},
        {"message": "Message 3 from thread 2", "value": 200},
    ]
    assert result == expected