From 6fb15990ace21c38c31053749c5ef652adbc1603 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Fri, 10 May 2024 11:42:11 +0100 Subject: [PATCH] chore(internal): bytecode wrapping context We introduce a new mechanism of wrapping functions via a special context manager that is capable of capturing return values as well. The goal is to allow observability into the called functions, to have access to local variables on exit. This approach has the extra benefit of not introducing any extra frames in the call stack of the wrapped function. --- ddtrace/internal/assembly.py | 3 + ddtrace/internal/wrapping/context.py | 672 +++++++++++++++++++++++++++ tests/internal/test_wrapping.py | 192 ++++++++ 3 files changed, 867 insertions(+) create mode 100644 ddtrace/internal/wrapping/context.py diff --git a/ddtrace/internal/assembly.py b/ddtrace/internal/assembly.py index 9d502995a001..c1740192540d 100644 --- a/ddtrace/internal/assembly.py +++ b/ddtrace/internal/assembly.py @@ -272,3 +272,6 @@ def dis(self) -> None: def __iter__(self) -> t.Iterator[bc.Instr]: return iter(self._instrs) + + def __len__(self) -> int: + return len(self._instrs) diff --git a/ddtrace/internal/wrapping/context.py b/ddtrace/internal/wrapping/context.py new file mode 100644 index 000000000000..b047f8b5aa1f --- /dev/null +++ b/ddtrace/internal/wrapping/context.py @@ -0,0 +1,672 @@ +from inspect import iscoroutinefunction +import sys +from types import FrameType +from types import FunctionType +import typing as t + + +try: + from typing import Protocol # noqa:F401 +except ImportError: + from typing_extensions import Protocol # type: ignore[assignment] + +import bytecode +from bytecode import Bytecode + +from ddtrace.internal.assembly import Assembly + + +T = t.TypeVar("T") + +# This module implements utilities for wrapping a function with a context +# manager. The rough idea is to re-write the function's bytecode to look like +# this: +# +# def foo(): +# with wrapping_context: +# # Original function code +# +# Because we also want to capture the return value, our context manager extends +# the Python one by implementing a __return__ method that will be called with +# the return value of the function. The __exit__ method is only called if the +# function raises an exception. +# +# Because CPython 3.11 introduced zero-cost exceptions, we cannot nest try +# blocks in the function's bytecode. In this case, we call the context manager +# methods directly at the right places, and set up the appropriate exception +# handling code. For older versions of Python we rely on the with statement to +# perform entry and exit operations. Calls to __return__ are explicit in all +# cases. +# +# Some advantages of wrapping a function this way are: +# - Access to the local variables on entry and on return/exit via the frame +# object. +# - No intermediate function calls that pollute the call stack. +# +# The actual bytecode wrapping is performed once on a target function via a +# universal wrapping context. Multiple context wrapping of a function is allowed +# and it is virtually implemented on top of the concrete universal wrapping +# context. This makes multiple wrapping/unwrapping easy, as it translates to a +# single bytecode wrapping/unwrapping operation. +# +# Context wrappers should be implemented as subclasses of the WrappingContext +# class. The __priority__ attribute can be used to control the order in which +# multiple context wrappers are entered and exited. The __frame__ attribute +# contains the frame object of the function being wrapped. The __enter__ and +# __exit__ methods should be implemented to perform the necessary operations. + +CONTEXT_HEAD = Assembly() +CONTEXT_RETURN = Assembly() +CONTEXT_FOOT = Assembly() + +if sys.version_info >= (3, 12): + CONTEXT_HEAD.parse( + r""" + push_null + load_const {context_enter} + call 0 + pop_top + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context_return} + push_null + swap 3 + call 1 + """ + ) + + CONTEXT_RETURN_CONST = Assembly() + CONTEXT_RETURN_CONST.parse( + r""" + push_null + load_const {context_return} + load_const {value} + call 1 + """ + ) + + CONTEXT_EXC_HEAD = Assembly() + CONTEXT_EXC_HEAD.parse( + r""" + push_null + load_const {context_exit} + call 0 + pop_top + """ + ) + + CONTEXT_FOOT.parse( + r""" + try @_except lasti + push_exc_info + push_null + load_const {context_exit} + call 0 + pop_top + reraise 2 + tried + + _except: + copy 3 + pop_except + reraise 1 + """ + ) + + +elif sys.version_info >= (3, 11): + CONTEXT_HEAD.parse( + r""" + push_null + load_const {context_enter} + precall 0 + call 0 + pop_top + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context_return} + push_null + swap 3 + precall 1 + call 1 + """ + ) + + CONTEXT_EXC_HEAD = Assembly() + CONTEXT_EXC_HEAD.parse( + r""" + push_null + load_const {context_exit} + precall 0 + call 0 + pop_top + """ + ) + + CONTEXT_FOOT.parse( + r""" + try @_except lasti + push_exc_info + push_null + load_const {context_exit} + precall 0 + call 0 + pop_top + reraise 2 + tried + + _except: + copy 3 + pop_except + reraise 1 + """ + ) + +elif sys.version_info >= (3, 10): + CONTEXT_HEAD.parse( + r""" + load_const {context} + setup_with @_except + pop_top + _except: + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context} + load_method $__return__ + rot_three + rot_three + call_method 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + with_except_start + pop_top + reraise 1 + """ + ) + +elif sys.version_info >= (3, 9): + CONTEXT_HEAD.parse( + r""" + load_const {context} + setup_with @_except + pop_top + _except: + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context} + load_method $__return__ + rot_three + rot_three + call_method 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + with_except_start + pop_top + reraise + """ + ) + + +elif sys.version_info >= (3, 7): + CONTEXT_HEAD.parse( + r""" + load_const {context} + setup_with @_except + pop_top + _except: + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context} + load_method $__return__ + rot_three + rot_three + call_method 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + with_cleanup_start + with_cleanup_finish + end_finally + load_const None + return_value + """ + ) + + +# This is abstract and should not be used directly +class BaseWrappingContext(t.ContextManager): + __priority__: int = 0 + __frame__: t.Optional[FrameType] = None + + def __init__(self, f: FunctionType): + self.__wrapped__ = f + + @classmethod + def wrapped(cls, f: FunctionType) -> "BaseWrappingContext": + if cls.is_wrapped(f): + context = cls.extract(f) + assert isinstance(context, cls) # nosec + else: + context = cls(f) + context.wrap() + return context + + def __return__(self, value): + return value + + @classmethod + def is_wrapped(cls, _f: FunctionType) -> bool: + raise NotImplementedError + + @classmethod + def extract(cls, _f: FunctionType) -> "BaseWrappingContext": + raise NotImplementedError + + def wrap(self) -> None: + raise NotImplementedError + + def unwrap(self) -> None: + raise NotImplementedError + + +# This is the public interface exported by this module +class WrappingContext(BaseWrappingContext): + @classmethod + def is_wrapped(cls, f: FunctionType) -> bool: + try: + return bool(cls.extract(f)) + except ValueError: + return False + + @classmethod + def extract(cls, f: FunctionType) -> "WrappingContext": + if _UniversalWrappingContext.is_wrapped(f): + try: + return _UniversalWrappingContext.extract(f).registered(cls) + except KeyError: + pass + msg = f"Function is not wrapped with {cls}" + raise ValueError(msg) + + def wrap(self) -> None: + t.cast(_UniversalWrappingContext, _UniversalWrappingContext.wrapped(self.__wrapped__)).register(self) + + def unwrap(self) -> None: + f = self.__wrapped__ + + if _UniversalWrappingContext.is_wrapped(f): + _UniversalWrappingContext.extract(f).unregister(self) + + +class ContextWrappedFunction(Protocol): + """A wrapped function.""" + + __dd_context_wrapped__ = None # type: t.Optional[_UniversalWrappingContext] + + def __call__(self, *args, **kwargs): + pass + + +# This class provides an interface between single bytecode wrapping and multiple +# logical context wrapping +class _UniversalWrappingContext(BaseWrappingContext): + def __init__(self, f: FunctionType) -> None: + super().__init__(f) + + self._contexts: t.Dict[t.Type[WrappingContext], WrappingContext] = {} + + def register(self, context: WrappingContext) -> None: + _type = type(context) + if _type in self._contexts: + raise ValueError("Context already registered") + + self._contexts[_type] = context + + def unregister(self, context: WrappingContext) -> None: + _type = type(context) + if _type not in self._contexts: + raise ValueError("Context not registered") + + del self._contexts[_type] + + if not self._contexts: + self.unwrap() + + def is_registered(self, context: WrappingContext) -> bool: + return type(context) in self._contexts + + def registered(self, context_type: t.Type[WrappingContext]) -> WrappingContext: + return self._contexts[context_type] + + def __enter__(self) -> "_UniversalWrappingContext": + frame = sys._getframe(1) + for context in sorted(self._contexts.values(), key=lambda c: c.__priority__): + context.__frame__ = frame + context.__enter__() + return self + + def _exit(self) -> None: + self.__exit__(*sys.exc_info()) + + def __exit__(self, *exc) -> None: + if exc == (None, None, None): + # In Python 3.7 this gets called when the context manager is exited + # normally + return + + for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): + context.__exit__(*exc) + + def __return__(self, value: T) -> T: + for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): + context.__return__(value) + return value + + @classmethod + def is_wrapped(cls, f: FunctionType) -> bool: + return hasattr(f, "__dd_context_wrapped__") + + @classmethod + def extract(cls, f: FunctionType) -> "_UniversalWrappingContext": + if not cls.is_wrapped(f): + raise ValueError("Function is not wrapped") + return t.cast(_UniversalWrappingContext, t.cast(ContextWrappedFunction, f).__dd_context_wrapped__) + + if sys.version_info >= (3, 11): + + def wrap(self) -> None: + f = self.__wrapped__ + + if self.is_wrapped(f): + raise ValueError("Function already wrapped") + + bc = Bytecode.from_code(f.__code__) + + # Prefix every return + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + return_code = CONTEXT_RETURN.bind({"context_return": self.__return__}, lineno=instr.lineno) + elif sys.version_info >= (3, 12) and instr.name == "RETURN_CONST": # Python 3.12+ + return_code = CONTEXT_RETURN_CONST.bind( + {"context_return": self.__return__, "value": instr.arg}, lineno=instr.lineno + ) + else: + return_code = [] + + bc[i:i] = return_code + i += len(return_code) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Search for the RESUME instruction + for i, instr in enumerate(bc, 1): + try: + if instr.name == "RESUME": + break + except AttributeError: + # Not an instruction + pass + else: + i = 0 + + bc[i:i] = CONTEXT_HEAD.bind({"context_enter": self.__enter__}, lineno=f.__code__.co_firstlineno) + + # Add call to __exit__ in the exception handlers + exception_entries = set() + for instr in bc: + if isinstance(instr, bytecode.TryBegin): + exception_entries.add(instr.target) + + label_positions = [] + for i, instr in enumerate(bc, 1): + if isinstance(instr, bytecode.Label) and instr in exception_entries: + label_positions.append(i) + + for i in sorted(label_positions, reverse=True): + bc[i:i] = CONTEXT_EXC_HEAD.bind({"context_exit": self._exit}) + + # Wrap every line outside a try block + except_label = bytecode.Label() + first_try_begin = last_try_begin = bytecode.TryBegin(except_label, push_lasti=True) + + i = 0 + while i < len(bc): + instr = bc[i] + if isinstance(instr, bytecode.TryBegin) and last_try_begin is not None: + bc.insert(i, bytecode.TryEnd(last_try_begin)) + last_try_begin = None + i += 1 + elif isinstance(instr, bytecode.TryEnd): + # TODO: Skip wrapping around labels + if not isinstance(bc[i + 1], bytecode.TryBegin): + last_try_begin = bytecode.TryBegin(except_label, push_lasti=True) + bc.insert(i + 1, last_try_begin) + i += 1 + i += 1 + + bc.insert(0, first_try_begin) + + bc.append(bytecode.TryEnd(last_try_begin)) + bc.append(except_label) + bc.extend(CONTEXT_FOOT.bind({"context_exit": self._exit})) + + # Mark the function as wrapped by a wrapping context + t.cast(ContextWrappedFunction, f).__dd_context_wrapped__ = self + + # Replace the function code with the wrapped code + f.__code__ = bc.to_code() + + def unwrap(self) -> None: + f = self.__wrapped__ + + if not self.is_wrapped(f): + return + + wrapped = t.cast(ContextWrappedFunction, f) + + bc = Bytecode.from_code(f.__code__) + + # Remove the exception handling code + bc[-len(CONTEXT_FOOT) :] = [] + bc.pop() + bc.pop() + + except_label = bc.pop(0).target + + # Remove the try blocks + i = 0 + last_begin = None + while i < len(bc): + instr = bc[i] + if isinstance(instr, bytecode.TryBegin) and instr.target is except_label: + last_begin = bc.pop(i) + elif isinstance(instr, bytecode.TryEnd) and last_begin is not None and instr.entry is last_begin: + bc.pop(i) + last_begin = None + else: + i += 1 + + # Add call to __exit__ in the exception handlers + exception_entries = set() + for instr in bc: + if isinstance(instr, bytecode.TryBegin): + exception_entries.add(instr.target) + + label_positions = [] + for i, instr in enumerate(bc, 1): + if isinstance(instr, bytecode.Label) and instr in exception_entries: + label_positions.append(i) + + for i in sorted(label_positions, reverse=True): + bc[i : i + len(CONTEXT_EXC_HEAD)] = [] + + # Remove the head of the try block + wc = wrapped.__dd_context_wrapped__ + for i, instr in enumerate(bc): + try: + if instr.name == "LOAD_CONST" and instr.arg is wc: + break + except AttributeError: + # Not an instruction + pass + + # Search for the RESUME instruction + for i, instr in enumerate(bc, 1): + try: + if instr.name == "RESUME": + break + except AttributeError: + # Not an instruction + pass + else: + i = 0 + + bc[i : i + len(CONTEXT_HEAD)] = [] + + # Un-prefix every return + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + return_code = CONTEXT_RETURN + elif sys.version_info >= (3, 12) and instr.name == "RETURN_CONST": # Python 3.12+ + return_code = CONTEXT_RETURN_CONST + else: + return_code = None + + if return_code is not None: + bc[i - len(return_code) : i] = [] + i -= len(return_code) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Recreate the code object + f.__code__ = bc.to_code() + + # Remove the wrapping context marker + del wrapped.__dd_context_wrapped__ + + else: + + def wrap(self) -> None: + f = self.__wrapped__ + + if self.is_wrapped(f): + raise ValueError("Function already wrapped") + + bc = Bytecode.from_code(f.__code__) + + # Prefix every return + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + return_code = CONTEXT_RETURN.bind({"context": self}, lineno=instr.lineno) + else: + return_code = [] + + bc[i:i] = return_code + i += len(return_code) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Search for the GEN_START instruction + i = 0 + if sys.version_info >= (3, 10) and iscoroutinefunction(f): + for i, instr in enumerate(bc, 1): + try: + if instr.name == "GEN_START": + break + except AttributeError: + # Not an instruction + pass + + *bc[i:i], except_label = CONTEXT_HEAD.bind({"context": self}, lineno=f.__code__.co_firstlineno) + + bc.append(except_label) + bc.extend(CONTEXT_FOOT.bind()) + + # Mark the function as wrapped by a wrapping context + t.cast(ContextWrappedFunction, f).__dd_context_wrapped__ = self + + # Replace the function code with the wrapped code + f.__code__ = bc.to_code() + + def unwrap(self) -> None: + f = self.__wrapped__ + + if not self.is_wrapped(f): + return + + wrapped = t.cast(ContextWrappedFunction, f) + + bc = Bytecode.from_code(f.__code__) + + # Remove the exception handling code + bc[-len(CONTEXT_FOOT) :] = [] + bc.pop() + + # Remove the head of the try block + wc = wrapped.__dd_context_wrapped__ + for i, instr in enumerate(bc): + try: + if instr.name == "LOAD_CONST" and instr.arg is wc: + break + except AttributeError: + # Not an instruction + pass + + bc[i : i + len(CONTEXT_HEAD) - 1] = [] + + # Remove all the return handlers + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + bc[i - len(CONTEXT_RETURN) : i] = [] + i -= len(CONTEXT_RETURN) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Recreate the code object + f.__code__ = bc.to_code() + + # Remove the wrapping context marker + del wrapped.__dd_context_wrapped__ diff --git a/tests/internal/test_wrapping.py b/tests/internal/test_wrapping.py index 3f7a73b19ef7..56f0a98a2643 100644 --- a/tests/internal/test_wrapping.py +++ b/tests/internal/test_wrapping.py @@ -7,6 +7,8 @@ from ddtrace.internal.wrapping import unwrap from ddtrace.internal.wrapping import wrap +from ddtrace.internal.wrapping.context import WrappingContext +from ddtrace.internal.wrapping.context import _UniversalWrappingContext def assert_stack(expected): @@ -491,3 +493,193 @@ def f(a, b, c=None): assert closure(1, 2, 3) == (1, 2, 3, 42) assert channel == [((42,), {}), closure, ((1, 2, 3), {}), (1, 2, 3, 42)] + + +NOTSET = object() + + +class DummyWrappingContext(WrappingContext): + def __init__(self, f): + super().__init__(f) + + self.entered = False + self.exited = False + self.return_value = NOTSET + self.exc_info = None + + def __enter__(self): + self.entered = True + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.exited = True + if exc_value is not None: + self.exc_info = (exc_type, exc_value, traceback) + + def __return__(self, value): + self.return_value = value + return value + + +def test_wrapping_context_happy(): + def foo(): + return 42 + + wc = DummyWrappingContext(foo) + wc.wrap() + + assert foo() == 42 + + assert wc.entered + assert wc.return_value == 42 + assert not wc.exited + assert wc.exc_info is None + + +def test_wrapping_context_unwrapping(): + def foo(): + return 42 + + wc = DummyWrappingContext(foo) + wc.wrap() + assert _UniversalWrappingContext.is_wrapped(foo) + + wc.unwrap() + assert not _UniversalWrappingContext.is_wrapped(foo) + + assert foo() == 42 + + assert not wc.entered + assert wc.return_value is NOTSET + assert not wc.exited + assert wc.exc_info is None + + +def test_wrapping_context_exc(): + def foo(): + raise ValueError("foo") + + wc = DummyWrappingContext(foo) + wc.wrap() + + with pytest.raises(ValueError): + foo() + + assert wc.entered + assert wc.return_value is NOTSET + assert wc.exited + + _type, exc, _ = wc.exc_info + assert _type == ValueError + assert exc.args == ("foo",) + + +def test_wrapping_context_exc_on_exit(): + class BrokenExitWrappingContext(DummyWrappingContext): + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + raise RuntimeError("broken") + + def foo(): + raise ValueError("foo") + + wc = BrokenExitWrappingContext(foo) + wc.wrap() + + with pytest.raises(RuntimeError): + foo() + + assert wc.entered + assert wc.return_value is NOTSET + assert wc.exited + + _type, exc, _ = wc.exc_info + assert _type == ValueError + assert exc.args == ("foo",) + + +def test_wrapping_context_priority(): + class HighPriorityWrappingContext(DummyWrappingContext): + def __enter__(self): + nonlocal mutated + + mutated = True + + return super().__enter__() + + def __return__(self, value): + nonlocal mutated + + assert not mutated + + return super().__return__(value) + + class LowPriorityWrappingContext(DummyWrappingContext): + __priority__ = 99 + + def __enter__(self): + nonlocal mutated + + assert mutated + + return super().__enter__() + + def __return__(self, value): + nonlocal mutated + + mutated = False + + return super().__return__(value) + + mutated = False + + def foo(): + return 42 + + hwc = HighPriorityWrappingContext(foo) + lwc = LowPriorityWrappingContext(foo) + + # Wrap low first. We want to make sure that hwc is entered first + lwc.wrap() + hwc.wrap() + + foo() + + assert lwc.entered + assert hwc.return_value == 42 + + +@pytest.mark.asyncio +async def test_wrapping_context_async_happy() -> None: + async def coro(): + return 1 + + wc = DummyWrappingContext(coro) + wc.wrap() + + assert await coro() == 1 + + assert wc.entered + assert wc.return_value == 1 + assert not wc.exited + assert wc.exc_info is None + + +@pytest.mark.asyncio +async def test_wrapping_context_async_exc() -> None: + async def coro(): + raise ValueError("foo") + + wc = DummyWrappingContext(coro) + wc.wrap() + + with pytest.raises(ValueError): + await coro() + + assert wc.entered + assert wc.return_value is NOTSET + assert wc.exited + + _type, exc, _ = wc.exc_info + assert _type is ValueError + assert exc.args == ("foo",)