Skip to content

Commit

Permalink
Let "logger.catch" be used as a context manager too
Browse files Browse the repository at this point in the history
  • Loading branch information
Delgan committed Nov 5, 2017
1 parent 6da8795 commit 041f209
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 132 deletions.
141 changes: 85 additions & 56 deletions loguru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,11 +578,84 @@ def stop(self):
self.file = None
self.file_path = None

class Catcher:

def __init__(self, logger, exception=BaseException, *, level=None, reraise=False,
message="An error has been caught in function '{function}', "
"process '{process.name}' ({process.id}), "
"thread '{thread.name}' ({thread.id}):"):
self.logger = logger
self.exception = exception
self.level = level
self.reraise = reraise
self.message = message

self.function_name = None
self.exception_logger = self.logger.exception

def __enter__(self):
pass

def __exit__(self, type_, value, traceback_):
if type_ is None:
return

if not issubclass(type_, self.exception):
return False

thread = current_thread()
thread_recattr = ThreadRecattr(thread.ident)
thread_recattr.id, thread_recattr.name = thread.ident, thread.name

process = current_process()
process_recattr = ProcessRecattr(process.ident)
process_recattr.id, process_recattr.name = process.ident, process.name

function_name = self.function_name
if function_name is None:
function_name = getframe(1).f_code.co_name

record = {
'process': process_recattr,
'thread': thread_recattr,
'function': function_name,
}

if self.level is not None:
# TODO: Use logger function accordingly
raise NotImplementedError

self.exception_logger(self.message.format_map(record))

return not self.reraise

def __call__(self, *args, **kwargs):
if not kwargs and len(args) == 1:
arg = args[0]
if callable(arg) and (not isclass(arg) or not issubclass(arg, BaseException)):
function = arg
function_name = function.__name__

@functools.wraps(function)
def catch_wrapper(*args, **kwargs):
# TODO: Check it could be any conflict with multiprocessing because of self modification
self.function_name = function_name
self.exception_logger = self.logger._exception_catcher
with self:
function(*args, **kwargs)
self.function_name = None
self.exception_logger = self.logger.exception

return catch_wrapper

return Catcher(self.logger, *args, **kwargs)

class Logger:

def __init__(self):
self.handlers_count = 0
self.handlers = {}
self.catch = Catcher(self)

def log_to(self, sink, *, level=DEBUG, format=VERBOSE_FORMAT, filter=None, colored=None, better_exceptions=True, **kwargs):
if isclass(sink):
Expand Down Expand Up @@ -656,54 +729,6 @@ def stop(self, handler_id=None):

return 0

def catch(self, *args, **kwargs):

def catch_decorator(wrapped_function,
exception=BaseException, *,
message="An error has been caught in function '{function}', "
"process '{process.name}' ({process.id}), "
"thread '{thread.name}' ({thread.id}):",
level=None, reraise=False):

if level is not None:
# TODO: Call log function accordingly
raise NotImplementedError

@functools.wraps(wrapped_function)
def catch_wrapper(*args, **kwargs):
try:
wrapped_function(*args, **kwargs)
except exception:
thread = current_thread()
thread_recattr = ThreadRecattr(thread.ident)
thread_recattr.id, thread_recattr.name = thread.ident, thread.name

process = current_process()
process_recattr = ProcessRecattr(process.ident)
process_recattr.id, process_recattr.name = process.ident, process.name

function_name = wrapped_function.__name__

record = {
'process': process_recattr,
'thread': thread_recattr,
'function': function_name,
}

self._exception_catcher(message.format_map(record))

if reraise:
raise

return catch_wrapper

if not kwargs and len(args) == 1:
arg = args[0]
if callable(arg) and (not isclass(arg) or not issubclass(arg, BaseException)):
return catch_decorator(arg)

return lambda f: catch_decorator(f, *args, **kwargs)

@staticmethod
def make_log_function(level, log_exception=0):

Expand Down Expand Up @@ -759,20 +784,24 @@ def log_function(self, message, *args, **kwargs):
loguru_tb = new_tb
tb = tb.tb_next

caught_tb_marked = False

if log_exception == 1:
root_tb.__is_caught_point__ = True
caught_tb_marked = True
caught_tb = root_tb

while root_frame:
if root_frame.f_code.co_filename != __file__:
root_tb = loguru_traceback(root_frame, root_frame.f_lasti, root_frame.f_lineno, root_tb)
if not caught_tb_marked:
root_tb.__is_caught_point__ = True
caught_tb_marked = True
root_frame = root_frame.f_back

if log_exception == 1:
caught_tb.__is_caught_point__ = True
else:
tb_prev = tb_next = root_tb
while tb_next:
if tb_next == caught_tb:
break
tb_prev, tb_next = tb_next, tb_next.tb_next
tb_prev.__is_caught_point__ = True


exception = (ex_type, ex, root_tb)

record = {
Expand Down
1 change: 0 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import loguru
import pytest
import sys

@pytest.fixture
def logger():
Expand Down
153 changes: 105 additions & 48 deletions tests/test_catch_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,75 +1,99 @@
import pytest
import traceback

@pytest.mark.parametrize('args, kwargs', [
([], {}),
([2, 0], {}),
([4], {'b': 0}),
([], {'a': 8}),
])
def test_wrapped(logger, writer, args, kwargs):
zero_division_error = 'ZeroDivisionError: division by zero\n'
use_decorator = pytest.mark.parametrize('use_decorator', [True, False])

@pytest.mark.parametrize('use_parentheses', [True, False])
def test_decorator(logger, writer, use_parentheses):
logger.log_to(writer)

@logger.catch
def c(a=1, b=0):
a / b
if use_parentheses:
@logger.catch()
def c(a, b):
a / b
c(5, b=0)
else:
@logger.catch
def c(a, b=0):
a / b
c(2)

c(*args, **kwargs)
assert writer.read().endswith(zero_division_error)

assert writer.read().endswith('ZeroDivisionError: division by zero\n')
@pytest.mark.parametrize('use_parentheses', [True, False])
def test_context_manager(logger, writer, use_parentheses):
logger.log_to(writer)

def test_wrapped_better_exceptions(logger, writer):
if use_parentheses:
with logger.catch():
1 / 0
else:
with logger.catch:
1 / 0

assert writer.read().endswith(zero_division_error)

def test_with_better_exceptions(logger, writer):
logger.log_to(writer, better_exceptions=True)

@logger.catch()
def c():
a = 2
b = 0
a / b

c()
decorated = logger.catch(c)
decorated()

result_with = writer.read().strip()
result_with = writer.read()

logger.stop()
writer.clear()

logger.log_to(writer, better_exceptions=False)

@logger.catch()
def c():
a = 2
b = 0
a / b
decorated = logger.catch(c)
decorated()

c()

result_without = writer.read().strip()
result_without = writer.read()

assert len(result_with) > len(result_without)
assert result_with.endswith(zero_division_error)
assert result_without.endswith(zero_division_error)

def test_custom_message(logger, writer):
@use_decorator
def test_custom_message(logger, writer, use_decorator):
logger.log_to(writer, format='{message}')
message = 'An error occured:'

@logger.catch(message='An error occured:')
def a():
1 / 0

a()
if use_decorator:
@logger.catch(message=message)
def a():
1 / 0
a()
else:
with logger.catch(message=message):
1 / 0

assert writer.read().startswith('An error occured:\n')
assert writer.read().startswith(message + '\n')

def test_reraise(logger, writer):
@use_decorator
def test_reraise(logger, writer, use_decorator):
logger.log_to(writer)

@logger.catch(reraise=True)
def a():
1 / 0
if use_decorator:
@logger.catch(reraise=True)
def a():
1 / 0
else:
def a():
with logger.catch(reraise=True):
1 / 0

with pytest.raises(ZeroDivisionError):
a()

assert writer.read().endswith('ZeroDivisionError: division by zero\n')
assert writer.read().endswith(zero_division_error)

@pytest.mark.parametrize('exception, should_raise', [
(ZeroDivisionError, False),
Expand All @@ -78,32 +102,65 @@ def a():
((SyntaxError, TypeError), True),
])
@pytest.mark.parametrize('keyword', [True, False])
def test_exception(logger, writer, exception, should_raise, keyword):
@use_decorator
def test_exception(logger, writer, exception, should_raise, keyword, use_decorator):
logger.log_to(writer)

if keyword:
@logger.catch(exception=exception)
def a():
1 / 0
if use_decorator:
@logger.catch(exception=exception)
def a():
1 / 0
else:
def a():
with logger.catch(exception=exception):
1 / 0
else:
@logger.catch(exception)
def a():
1 / 0
if use_decorator:
@logger.catch(exception)
def a():
1 / 0
else:
def a():
with logger.catch(exception):
1 / 0

if should_raise:
with pytest.raises(ZeroDivisionError):
a()
assert writer.read() == ''
else:
a()
assert writer.read().endswith('ZeroDivisionError: division by zero\n')
assert writer.read().endswith(zero_division_error)

@use_decorator
def test_not_raising(logger, writer, use_decorator):
logger.log_to(writer, format='{message}')
message = "It's ok"

if use_decorator:
@logger.catch
def a():
logger.debug(message)
a()
else:
with logger.catch:
logger.debug(message)

assert writer.read() == message + '\n'

@pytest.mark.xfail
def test_custom_level(logger, writter):
@use_decorator
def test_custom_level(logger, writter, use_decorator):
logger.log_to(writer)

@logger.catch(level=10)
def a():
1 / 0
if use_decorator:
@logger.catch(level=10)
def a():
1 / 0
else:
def a():
with logger.catch(level=10):
1 / 0

a()

0 comments on commit 041f209

Please sign in to comment.