Skip to content

Commit

Permalink
Make logged exceptions displaying the whole traceback
Browse files Browse the repository at this point in the history
  • Loading branch information
Delgan committed Nov 3, 2017
1 parent b1ce326 commit d4970e8
Show file tree
Hide file tree
Showing 5 changed files with 531 additions and 28 deletions.
104 changes: 94 additions & 10 deletions loguru/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sys import exc_info, stdout as STDOUT, stderr as STDERR
from multiprocessing import current_process
from threading import current_thread
from traceback import format_exception
import traceback
from numbers import Number
import shutil
import re
Expand All @@ -18,6 +18,7 @@
from string import Formatter
import math
import functools
import uuid

import ansimarkup
from better_exceptions_fork import ExceptionFormatter
Expand Down Expand Up @@ -78,9 +79,32 @@ def patch_datetime_file(date):
date._FORMATTER = 'alternative'
date._to_string_format = '%Y-%m-%d_%H-%M-%S'

class loguru_traceback:
__slots__ = ('tb_frame', 'tb_lasti', 'tb_lineno', 'tb_next', '__is_caught_point__')

def __init__(self, frame, lasti, lineno, next_=None, is_caught_point=False):
self.tb_frame = frame
self.tb_lasti = lasti
self.tb_lineno = lineno
self.tb_next = next_
self.__is_caught_point__ = is_caught_point


class StrRecord(str):
pass

class HackyInt(int):

rand = str(uuid.uuid4().int) + str(uuid.uuid4().int) # 32 bytes

def __str__(self):
self.true_value = int(repr(self))
self.false_value = '0' + repr(self) + self.rand
return self.false_value

def __eq__(self, other):
return False

class Handler:

def __init__(self, *, writter, level, format_, filter_, colored, better_exceptions):
Expand Down Expand Up @@ -118,11 +142,39 @@ def emit(self, record):
exception = record['exception']

formatted = self.formats_per_level[level.name].format_map(record) + '\n'

if exception:
hacked = None
tb = exception[2]
while tb:
if tb.__is_caught_point__:
hacked = HackyInt(tb.tb_lineno)
tb.tb_lineno = hacked
break
tb = tb.tb_next

if self.better_exceptions:
formatted += ''.join(self.exception_formatter.format_exception(*exception))
formatted_exception = self.exception_formatter.format_exception(*exception)
else:
formatted += ''.join(format_exception(*exception))
formatted_exception = traceback.format_exception(*exception)

formatted_exception = ''.join(formatted_exception)

tb_reg = r'Traceback \(most recent call last\):'
ansi_reg = r'[a-zA-Z0-9;\\\[]*'
hacky_reg = r'^({ansi})({tb})({ansi})$((?:(?!^{ansi}{tb}{ansi}$)[\s\S])*)^({ansi})( )({ansi}File.*line{ansi} {ansi})({line})({ansi},.*)$'.format(tb=tb_reg, ansi=ansi_reg, line=str(hacked.false_value))

def mark_catch_point(match):
m_1, tb, m_2, m_3, m_4, s, m_5, line, m_6 = match.groups()
tb = 'Traceback (most recent call last, catch point marked):'
s = '> '
line = str(hacked.true_value)
return ''.join([m_1, tb, m_2, m_3, m_4, s, m_5, line, m_6])

formatted_exception = re.sub(hacky_reg, mark_catch_point, formatted_exception, count=1, flags=re.M)

formatted += formatted_exception


message = StrRecord(formatted)
message.record = record
Expand Down Expand Up @@ -618,7 +670,7 @@ def catch_decorator(wrapped_function,
def catch_wrapper(*args, **kwargs):
try:
wrapped_function(*args, **kwargs)
except exception as e:
except exception:
thread = current_thread()
thread_recattr = ThreadRecattr(thread.ident)
thread_recattr.id, thread_recattr.name = thread.ident, thread.name
Expand All @@ -635,8 +687,7 @@ def catch_wrapper(*args, **kwargs):
'function': function_name,
}

# TODO: Use the stacktrace from 'e' rather than calling sys.exc_info() in
self.exception(message.format_map(record))
self._exception_catcher(message.format_map(record))

if reraise:
raise
Expand All @@ -651,7 +702,7 @@ def catch_wrapper(*args, **kwargs):
return lambda f: catch_decorator(f, *args, **kwargs)

@staticmethod
def make_log_function(level, log_exception=False):
def make_log_function(level, log_exception=0):

level_name = getLevelName(level)

Expand Down Expand Up @@ -686,7 +737,40 @@ def log_function(self, message, *args, **kwargs):
process_recattr = ProcessRecattr(process.ident)
process_recattr.id, process_recattr.name = process.ident, process.name

exception = exc_info() if log_exception else None
exception = None
if log_exception:
ex_type, ex, tb = exc_info()

root_frame = tb.tb_frame.f_back

# TODO: Test edge cases (look in CPython source code for traceback objects and exc.__traceback__ usages)

loguru_tb = root_tb = None
while tb:
if tb.tb_frame.f_code.co_filename != __file__:
new_tb = loguru_traceback(tb.tb_frame, tb.tb_lasti, tb.tb_lineno, None)
if loguru_tb:
loguru_tb.tb_next = new_tb
else:
root_tb = new_tb
loguru_tb = new_tb
tb = tb.tb_next

caught_tb_marked = False

if log_exception == 1:
root_tb.__is_caught_point__ = True
caught_tb_marked = True

while root_frame:
if root_frame.f_code.co_filename != __file__:
root_tb = loguru_traceback(root_frame, root_frame.f_lasti, root_frame.f_lineno, root_tb)
if not caught_tb_marked:
root_tb.__is_caught_point__ = True
caught_tb_marked = True
root_frame = root_frame.f_back

exception = (ex_type, ex, root_tb)

record = {
'name': name,
Expand Down Expand Up @@ -719,9 +803,9 @@ def log_function(self, message, *args, **kwargs):
success = make_log_function.__func__(SUCCESS)
warning = make_log_function.__func__(WARNING)
error = make_log_function.__func__(ERROR)
exception = make_log_function.__func__(ERROR, True)
exception = make_log_function.__func__(ERROR, 1)
_exception_catcher = make_log_function.__func__(ERROR, 2)
critical = make_log_function.__func__(CRITICAL)


logger = Logger()
logger.log_to(STDERR)
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ def w(message):

w.written = []
w.read = lambda: ''.join(w.written)
w.clear = lambda: w.written.clear()

return w
31 changes: 20 additions & 11 deletions tests/test_catch_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import traceback

@pytest.mark.parametrize('args, kwargs', [
([], {}),
Expand All @@ -17,9 +18,8 @@ def c(a=1, b=0):

assert writer.read().endswith('ZeroDivisionError: division by zero\n')

@pytest.mark.parametrize('better_exceptions', [True, False])
def test_wrapped_better_exceptions(logger, writer, better_exceptions):
logger.log_to(writer, better_exceptions=better_exceptions)
def test_wrapped_better_exceptions(logger, writer):
logger.log_to(writer, better_exceptions=True)

@logger.catch()
def c():
Expand All @@ -29,12 +29,24 @@ def c():

c()

length = len(writer.read().splitlines())
result_with = writer.read().strip()

if better_exceptions:
assert length == 15
else:
assert length == 7
logger.stop()
writer.clear()

logger.log_to(writer, better_exceptions=False)

@logger.catch()
def c():
a = 2
b = 0
a / b

c()

result_without = writer.read().strip()

assert len(result_with) > len(result_without)

def test_custom_message(logger, writer):
logger.log_to(writer, format='{message}')
Expand Down Expand Up @@ -86,9 +98,6 @@ def a():
a()
assert writer.read().endswith('ZeroDivisionError: division by zero\n')

def test_frame(logger, writer):
pass

@pytest.mark.xfail
def test_custom_level(logger, writter):
logger.log_to(writer)
Expand Down

0 comments on commit d4970e8

Please sign in to comment.