feat: logger inheritance #99

Merged
3 changes: 3 additions & 0 deletions aws_lambda_powertools/logging/formatter.py
@@ -56,6 +56,9 @@ def __init__(self, **kwargs):
self.format_dict.update(kwargs)
self.default_json_formatter = kwargs.pop("json_default", json_formatter)

def update_formatter(self, **kwargs):
self.format_dict.update(kwargs)

def format(self, record): # noqa: A003
record_dict = record.__dict__.copy()
record_dict["asctime"] = self.formatTime(record, self.datefmt)
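The new `update_formatter` hook lets callers append keys to an existing formatter instead of building a replacement. Below is a minimal sketch of the idea, assuming `JsonFormatter` is importable from `aws_lambda_powertools.logging.formatter` as the file path above suggests; in normal use `Logger.structure_logs(append=True, ...)` calls this for you.

```python
# Sketch only; exercises JsonFormatter.update_formatter directly.
import io
import logging

from aws_lambda_powertools.logging.formatter import JsonFormatter  # assumed import path

stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(JsonFormatter(service="payment"))

log = logging.getLogger("formatter-demo")
log.setLevel(logging.INFO)
log.addHandler(handler)

log.info("before")                                     # JSON record including "service"
handler.formatter.update_formatter(payment_id="123")   # append a key in place
log.info("after")                                      # same record plus "payment_id"
print(stream.getvalue())
```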
91 changes: 57 additions & 34 deletions aws_lambda_powertools/logging/logger.py
@@ -1,4 +1,3 @@
import copy
import functools
import logging
import os
@@ -34,7 +33,7 @@ def _is_cold_start() -> bool:
return cold_start


class Logger(logging.Logger):
class Logger:
"""Creates and setups a logger to format statements in JSON.

Includes service name and any additional key=value into logs
@@ -55,6 +54,8 @@ class Logger(logging.Logger):
service name to be appended in logs, by default "service_undefined"
level : str, optional
logging.level, by default "INFO"
name: str
Logger name, "{service}" by default
sample_rate: float, optional
sample rate for debug calls within execution context defaults to 0.0
stream: sys.stdout, optional
@@ -80,7 +81,7 @@ class Logger(logging.Logger):
>>> def handler(event, context):
logger.info("Hello")

**Append payment_id to previously setup structured log logger**
**Append payment_id to previously setup logger**

>>> from aws_lambda_powertools import Logger
>>> logger = Logger(service="payment")
@@ -89,18 +90,15 @@ class Logger(logging.Logger):
logger.structure_logs(append=True, payment_id=event["payment_id"])
logger.info("Hello")

Parameters
----------
logging : logging.Logger
Inherits Logger
service: str
name of the service to create the logger for, "service_undefined" by default
level: str, int
log level, INFO by default
sampling_rate: float
debug log sampling rate, 0.0 by default
stream: sys.stdout
log stream, stdout by default
**Create child Logger using logging inheritance via name param**

>>> # app.py
>>> from aws_lambda_powertools import Logger
>>> logger = Logger(name="payment")
>>>
>>> # another_file.py
>>> from aws_lambda_powertools import Logger
>>> logger = Logger(name="payment.child")

Raises
------
@@ -112,19 +110,48 @@ def __init__(
self,
service: str = None,
level: Union[str, int] = None,
name: str = None,
sampling_rate: float = None,
stream: sys.stdout = None,
**kwargs,
):
self.service = service or os.getenv("POWERTOOLS_SERVICE_NAME") or "service_undefined"
self.name = name or self.service
self.sampling_rate = sampling_rate or os.getenv("POWERTOOLS_LOGGER_SAMPLE_RATE") or 0.0
self.log_level = level or os.getenv("LOG_LEVEL") or logging.INFO
self.handler = logging.StreamHandler(stream) if stream is not None else logging.StreamHandler(sys.stdout)
self.log_level = level or os.getenv("LOG_LEVEL".upper()) or logging.INFO
self._handler = logging.StreamHandler(stream) if stream is not None else logging.StreamHandler(sys.stdout)
self._default_log_keys = {"service": self.service, "sampling_rate": self.sampling_rate}
self.log_keys = copy.copy(self._default_log_keys)

super().__init__(name=self.service, level=self.log_level)

self._logger = logging.getLogger(self.name)

self._init_logger(**kwargs)

def __getattr__(self, name):
# Proxy attributes not found to actual logger to support backward compatibility
# https://github.com/awslabs/aws-lambda-powertools-python/issues/97
return getattr(self._logger, name)

def _init_logger(self, **kwargs):
"""Configures new logger"""
# Ensure logger children remains independent
self._logger.propagate = False

# Skip configuration if it's a child logger, or a handler is already present
# to prevent multiple Loggers with the same name or its children having different sampling mechanisms
# and multiple messages from being logged as handlers can be duplicated
if self._logger.parent.name == "root" and not self._logger.handlers:
self._configure_sampling()
self._logger.setLevel(self.log_level)
self._logger.addHandler(self._handler)
self.structure_logs(**kwargs)

def _configure_sampling(self):
"""Dynamically set log level based on sampling rate

Raises
------
InvalidLoggerSamplingRateError
When sampling rate provided is not a float
"""
try:
if self.sampling_rate and random.random() <= float(self.sampling_rate):
logger.debug("Setting log level to Debug due to sampling rate")
@@ -134,12 +161,8 @@ def __init__(
f"Expected a float value ranging 0 to 1, but received {self.sampling_rate} instead. Please review POWERTOOLS_LOGGER_SAMPLE_RATE environment variable." # noqa E501
)

self.setLevel(self.log_level)
self.structure_logs(**kwargs)
self.addHandler(self.handler)

def inject_lambda_context(self, lambda_handler: Callable[[Dict, Any], Any] = None, log_event: bool = False):
"""Decorator to capture Lambda contextual info and inject into struct logging
"""Decorator to capture Lambda contextual info and inject into logger

Parameters
----------
@@ -216,21 +239,21 @@ def structure_logs(self, append: bool = False, **kwargs):
append : bool, optional
[description], by default False
"""
self.handler.setFormatter(JsonFormatter(**self._default_log_keys, **kwargs))

if append:
new_keys = {**self.log_keys, **kwargs}
self.handler.setFormatter(JsonFormatter(**new_keys))

self.log_keys.update(**kwargs)
for handler in self._logger.handlers:
if append:
# Update existing formatter in an existing logger handler
handler.formatter.update_formatter(**kwargs)
else:
# Set a new formatter for a logger handler
handler.setFormatter(JsonFormatter(**self._default_log_keys, **kwargs))


def set_package_logger(
level: Union[str, int] = logging.DEBUG, stream: sys.stdout = None, formatter: logging.Formatter = None
):
"""Set an additional stream handler, formatter, and log level for aws_lambda_powertools package logger.

**Package log by default is supressed (NullHandler), this should only used for debugging.
**Package log by default is suppressed (NullHandler), this should only be used for debugging.
This is separate from application Logger class utility**

Example
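Putting the pieces together, the docstring examples and the reworked `structure_logs` imply usage roughly like the sketch below. `Logger` is no longer a `logging.Logger` subclass, so standard calls such as `.info()` reach the underlying stdlib logger through the `__getattr__` proxy shown above.

```python
# Sketch based on the docstring examples in this diff.
from aws_lambda_powertools import Logger

# app.py — parent logger, configured once (handler, JSON formatter, sampling)
logger = Logger(service="payment", name="payment")

# another_file.py — child logger attaches to the "payment" hierarchy by name
child_logger = Logger(service="payment", name="payment.child")


def handler(event, context):
    # append=True now updates the existing JsonFormatter in place via
    # update_formatter, rather than replacing it with a new formatter
    logger.structure_logs(append=True, payment_id=event["payment_id"])
    logger.info("Collecting payment")
```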
48 changes: 13 additions & 35 deletions tests/functional/test_aws_lambda_logging.py
@@ -1,38 +1,20 @@
"""aws_lambda_logging tests."""
import io
import json
import logging

import pytest

from aws_lambda_powertools import Logger
from .utility_functions import get_random_logger


@pytest.fixture
def stdout():
return io.StringIO()


@pytest.fixture
def handler(stdout):
return logging.StreamHandler(stdout)


@pytest.fixture
def logger():
return logging.getLogger(__name__)


@pytest.fixture
def root_logger(handler):
logging.root.addHandler(handler)
yield logging.root
logging.root.removeHandler(handler)


@pytest.mark.parametrize("level", ["DEBUG", "WARNING", "ERROR", "INFO", "CRITICAL"])
def test_setup_with_valid_log_levels(root_logger, stdout, level):
logger = Logger(level=level, stream=stdout, request_id="request id!", another="value")
def test_setup_with_valid_log_levels(stdout, level):
logger = get_random_logger(level=level, stream=stdout, request_id="request id!", another="value")
msg = "This is a test"
log_command = {
"INFO": logger.info,
@@ -55,8 +37,8 @@ def test_setup_with_valid_log_levels(root_logger, stdout, level):
assert "exception" not in log_dict


def test_logging_exception_traceback(root_logger, stdout):
logger = Logger(level="DEBUG", stream=stdout, request_id="request id!", another="value")
def test_logging_exception_traceback(stdout):
logger = get_random_logger(level="DEBUG", stream=stdout, request_id="request id!", another="value")

try:
raise Exception("Boom")
@@ -69,9 +51,9 @@ def test_logging_exception_traceback(root_logger, stdout):
assert "exception" in log_dict


def test_setup_with_invalid_log_level(root_logger, logger, stdout):
def test_setup_with_invalid_log_level(stdout):
with pytest.raises(ValueError) as e:
Logger(level="not a valid log level")
get_random_logger(level="not a valid log level")
assert "Unknown level" in e.value.args[0]


@@ -82,12 +64,8 @@ def check_log_dict(log_dict):
assert "message" in log_dict


def test_setup_with_bad_level_does_not_fail():
Logger("DBGG", request_id="request id!", another="value")


def test_with_dict_message(root_logger, stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_with_dict_message(stdout):
logger = get_random_logger(level="DEBUG", stream=stdout)

msg = {"x": "isx"}
logger.critical(msg)
@@ -97,8 +75,8 @@ def test_with_dict_message(root_logger, stdout):
assert msg == log_dict["message"]


def test_with_json_message(root_logger, stdout):
logger = Logger(stream=stdout)
def test_with_json_message(stdout):
logger = get_random_logger(stream=stdout)

msg = {"x": "isx"}
logger.info(json.dumps(msg))
@@ -108,8 +86,8 @@ def test_with_json_message(root_logger, stdout):
assert msg == log_dict["message"]


def test_with_unserialisable_value_in_message(root_logger, stdout):
logger = Logger(level="DEBUG", stream=stdout)
def test_with_unserialisable_value_in_message(stdout):
logger = get_random_logger(level="DEBUG", stream=stdout)

class X:
pass
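The rewritten tests import `get_random_logger` from a `utility_functions` module that is not included in this excerpt. A plausible sketch of such a helper, assuming its only job is to hand each test a `Logger` with a unique name so state (handlers, formatters, levels) never leaks between tests:

```python
# tests/functional/utility_functions.py — hypothetical sketch; the real
# helper is not shown in this diff. A random logger name guarantees each
# test configures a fresh logging.Logger under the hood.
import random
import string

from aws_lambda_powertools import Logger


def get_random_logger(**kwargs) -> Logger:
    random_name = "".join(random.choices(string.ascii_letters, k=16))
    return Logger(name=random_name, **kwargs)
```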