From 27348c07ee7ebcd8d061fc59c9e488dc914759f8 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Fri, 10 May 2024 11:42:11 +0100 Subject: [PATCH 1/6] chore(internal): bytecode wrapping context We introduce a new mechanism of wrapping functions via a special context manager that is capable of capturing return values as well. The goal is to allow observability into the called functions, to have access to local variables on exit. This approach has the extra benefit of not introducing any extra frames in the call stack of the wrapped function. --- ddtrace/internal/assembly.py | 3 + ddtrace/internal/wrapping/context.py | 651 +++++++++++++++++++++++++++ tests/internal/test_wrapping.py | 197 ++++++++ 3 files changed, 851 insertions(+) create mode 100644 ddtrace/internal/wrapping/context.py diff --git a/ddtrace/internal/assembly.py b/ddtrace/internal/assembly.py index 9d502995a00..c1740192540 100644 --- a/ddtrace/internal/assembly.py +++ b/ddtrace/internal/assembly.py @@ -272,3 +272,6 @@ def dis(self) -> None: def __iter__(self) -> t.Iterator[bc.Instr]: return iter(self._instrs) + + def __len__(self) -> int: + return len(self._instrs) diff --git a/ddtrace/internal/wrapping/context.py b/ddtrace/internal/wrapping/context.py new file mode 100644 index 00000000000..852da32dfe1 --- /dev/null +++ b/ddtrace/internal/wrapping/context.py @@ -0,0 +1,651 @@ +from inspect import iscoroutinefunction +import sys +from types import FrameType +from types import FunctionType +import typing as t + + +try: + from typing import Protocol # noqa:F401 +except ImportError: + from typing_extensions import Protocol # type: ignore[assignment] + +try: + from typing import override # type: ignore[attr-defined] +except ImportError: + from typing_extensions import override + + +import bytecode +from bytecode import Bytecode + +from ddtrace.internal.assembly import Assembly + + +T = t.TypeVar("T") + +# This module implements utilities for wrapping a function with a context +# manager. The rough idea is to re-write the function's bytecode to look like +# this: +# +# def foo(): +# with wrapping_context: +# # Original function code +# +# Because we also want to capture the return value, our context manager extends +# the Python one by implementing a __return__ method that will be called with +# the return value of the function. The __exit__ method is only called if the +# function raises an exception. +# +# Because CPython 3.11 introduced zero-cost exceptions, we cannot nest try +# blocks in the function's bytecode. In this case, we call the context manager +# methods directly at the right places, and set up the appropriate exception +# handling code. For older versions of Python we rely on the with statement to +# perform entry and exit operations. Calls to __return__ are explicit in all +# cases. +# +# Some advantages of wrapping a function this way are: +# - Access to the local variables on entry and on return/exit via the frame +# object. +# - No intermediate function calls that pollute the call stack. +# - No need to call the wrapped function manually. +# +# The actual bytecode wrapping is performed once on a target function via a +# universal wrapping context. Multiple context wrapping of a function is allowed +# and it is virtually implemented on top of the concrete universal wrapping +# context. This makes multiple wrapping/unwrapping easy, as it translates to a +# single bytecode wrapping/unwrapping operation. +# +# Context wrappers should be implemented as subclasses of the WrappingContext +# class. The __priority__ attribute can be used to control the order in which +# multiple context wrappers are entered and exited. The __enter__ and __exit__ +# methods should be implemented to perform the necessary operations. The +# __enter__ method takes an extra argument representing the frame object of the +# wrapped function. The __return__ method can be implemented to capture the +# return value of the wrapped function. If implemented, its return value will be +# used as the wrapped function return value. The wrapped function can be +# accessed via the __wrapped__ attribute. + +CONTEXT_HEAD = Assembly() +CONTEXT_RETURN = Assembly() +CONTEXT_FOOT = Assembly() + +if sys.version_info >= (3, 12): + CONTEXT_HEAD.parse( + r""" + push_null + load_const {context_enter} + call 0 + pop_top + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context_return} + push_null + swap 3 + call 1 + """ + ) + + CONTEXT_RETURN_CONST = Assembly() + CONTEXT_RETURN_CONST.parse( + r""" + push_null + load_const {context_return} + load_const {value} + call 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + try @_except lasti + push_exc_info + push_null + load_const {context_exit} + call 0 + pop_top + reraise 2 + tried + + _except: + copy 3 + pop_except + reraise 1 + """ + ) + + +elif sys.version_info >= (3, 11): + CONTEXT_HEAD.parse( + r""" + push_null + load_const {context_enter} + precall 0 + call 0 + pop_top + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context_return} + push_null + swap 3 + precall 1 + call 1 + """ + ) + + CONTEXT_EXC_HEAD = Assembly() + CONTEXT_EXC_HEAD.parse( + r""" + push_null + load_const {context_exit} + precall 0 + call 0 + pop_top + """ + ) + + CONTEXT_FOOT.parse( + r""" + try @_except lasti + push_exc_info + push_null + load_const {context_exit} + precall 0 + call 0 + pop_top + reraise 2 + tried + + _except: + copy 3 + pop_except + reraise 1 + """ + ) + +elif sys.version_info >= (3, 10): + CONTEXT_HEAD.parse( + r""" + load_const {context} + setup_with @_except + pop_top + _except: + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context} + load_method $__return__ + rot_three + rot_three + call_method 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + with_except_start + pop_top + reraise 1 + """ + ) + +elif sys.version_info >= (3, 9): + CONTEXT_HEAD.parse( + r""" + load_const {context} + setup_with @_except + pop_top + _except: + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context} + load_method $__return__ + rot_three + rot_three + call_method 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + with_except_start + pop_top + reraise + """ + ) + + +elif sys.version_info >= (3, 7): + CONTEXT_HEAD.parse( + r""" + load_const {context} + setup_with @_except + pop_top + _except: + """ + ) + + CONTEXT_RETURN.parse( + r""" + load_const {context} + load_method $__return__ + rot_three + rot_three + call_method 1 + """ + ) + + CONTEXT_FOOT.parse( + r""" + with_cleanup_start + with_cleanup_finish + end_finally + load_const None + return_value + """ + ) + + +# This is abstract and should not be used directly +class BaseWrappingContext(t.ContextManager): + __priority__: int = 0 + __frame__: t.Optional[FrameType] = None + + def __init__(self, f: FunctionType): + self.__wrapped__ = f + + @classmethod + def wrapped(cls, f: FunctionType) -> "BaseWrappingContext": + if cls.is_wrapped(f): + context = cls.extract(f) + assert isinstance(context, cls) # nosec + else: + context = cls(f) + context.wrap() + return context + + def __return__(self, value): + return value + + @classmethod + def is_wrapped(cls, _f: FunctionType) -> bool: + raise NotImplementedError + + @classmethod + def extract(cls, _f: FunctionType) -> "BaseWrappingContext": + raise NotImplementedError + + def wrap(self) -> None: + raise NotImplementedError + + def unwrap(self) -> None: + raise NotImplementedError + + +# This is the public interface exported by this module +class WrappingContext(BaseWrappingContext): + @override + def __enter__(self, _frame: FrameType) -> "WrappingContext": + raise NotImplementedError + + @classmethod + def is_wrapped(cls, f: FunctionType) -> bool: + try: + return bool(cls.extract(f)) + except ValueError: + return False + + @classmethod + def extract(cls, f: FunctionType) -> "WrappingContext": + if _UniversalWrappingContext.is_wrapped(f): + try: + return _UniversalWrappingContext.extract(f).registered(cls) + except KeyError: + pass + msg = f"Function is not wrapped with {cls}" + raise ValueError(msg) + + def wrap(self) -> None: + t.cast(_UniversalWrappingContext, _UniversalWrappingContext.wrapped(self.__wrapped__)).register(self) + + def unwrap(self) -> None: + f = self.__wrapped__ + + if _UniversalWrappingContext.is_wrapped(f): + _UniversalWrappingContext.extract(f).unregister(self) + + +class ContextWrappedFunction(Protocol): + """A wrapped function.""" + + __dd_context_wrapped__ = None # type: t.Optional[_UniversalWrappingContext] + + def __call__(self, *args, **kwargs): + pass + + +# This class provides an interface between single bytecode wrapping and multiple +# logical context wrapping +class _UniversalWrappingContext(BaseWrappingContext): + def __init__(self, f: FunctionType) -> None: + super().__init__(f) + + self._contexts: t.Dict[t.Type[WrappingContext], WrappingContext] = {} + + def register(self, context: WrappingContext) -> None: + _type = type(context) + if _type in self._contexts: + raise ValueError("Context already registered") + + self._contexts[_type] = context + + def unregister(self, context: WrappingContext) -> None: + _type = type(context) + if _type not in self._contexts: + raise ValueError("Context not registered") + + del self._contexts[_type] + + if not self._contexts: + self.unwrap() + + def is_registered(self, context: WrappingContext) -> bool: + return type(context) in self._contexts + + def registered(self, context_type: t.Type[WrappingContext]) -> WrappingContext: + return self._contexts[context_type] + + def __enter__(self) -> "_UniversalWrappingContext": + frame = sys._getframe(1) + for context in sorted(self._contexts.values(), key=lambda c: c.__priority__): + context.__enter__(frame) + return self + + def _exit(self) -> None: + self.__exit__(*sys.exc_info()) + + def __exit__(self, *exc) -> None: + if exc == (None, None, None): + # In Python 3.7 this gets called when the context manager is exited + # normally + return + + for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): + context.__exit__(*exc) + + def __return__(self, value: T) -> T: + for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): + context.__return__(value) + return value + + @classmethod + def is_wrapped(cls, f: FunctionType) -> bool: + return hasattr(f, "__dd_context_wrapped__") + + @classmethod + def extract(cls, f: FunctionType) -> "_UniversalWrappingContext": + if not cls.is_wrapped(f): + raise ValueError("Function is not wrapped") + return t.cast(_UniversalWrappingContext, t.cast(ContextWrappedFunction, f).__dd_context_wrapped__) + + if sys.version_info >= (3, 11): + + def wrap(self) -> None: + f = self.__wrapped__ + + if self.is_wrapped(f): + raise ValueError("Function already wrapped") + + bc = Bytecode.from_code(f.__code__) + + # Prefix every return + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + return_code = CONTEXT_RETURN.bind({"context_return": self.__return__}, lineno=instr.lineno) + elif sys.version_info >= (3, 12) and instr.name == "RETURN_CONST": # Python 3.12+ + return_code = CONTEXT_RETURN_CONST.bind( + {"context_return": self.__return__, "value": instr.arg}, lineno=instr.lineno + ) + else: + return_code = [] + + bc[i:i] = return_code + i += len(return_code) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Search for the RESUME instruction + for i, instr in enumerate(bc, 1): + try: + if instr.name == "RESUME": + break + except AttributeError: + # Not an instruction + pass + else: + i = 0 + + bc[i:i] = CONTEXT_HEAD.bind({"context_enter": self.__enter__}, lineno=f.__code__.co_firstlineno) + + # Wrap every line outside a try block + except_label = bytecode.Label() + first_try_begin = last_try_begin = bytecode.TryBegin(except_label, push_lasti=True) + + i = 0 + while i < len(bc): + instr = bc[i] + if isinstance(instr, bytecode.TryBegin) and last_try_begin is not None: + bc.insert(i, bytecode.TryEnd(last_try_begin)) + last_try_begin = None + i += 1 + elif isinstance(instr, bytecode.TryEnd): + j = i + 1 + while j < len(bc) and not isinstance(bc[j], bytecode.TryBegin): + if isinstance(bc[j], bytecode.Instr): + last_try_begin = bytecode.TryBegin(except_label, push_lasti=True) + bc.insert(i + 1, last_try_begin) + break + j += 1 + i += 1 + i += 1 + + bc.insert(0, first_try_begin) + + bc.append(bytecode.TryEnd(last_try_begin)) + bc.append(except_label) + bc.extend(CONTEXT_FOOT.bind({"context_exit": self._exit})) + + # Mark the function as wrapped by a wrapping context + t.cast(ContextWrappedFunction, f).__dd_context_wrapped__ = self + + # Replace the function code with the wrapped code + f.__code__ = bc.to_code() + + def unwrap(self) -> None: + f = self.__wrapped__ + + if not self.is_wrapped(f): + return + + wrapped = t.cast(ContextWrappedFunction, f) + + bc = Bytecode.from_code(f.__code__) + + # Remove the exception handling code + bc[-len(CONTEXT_FOOT) :] = [] + bc.pop() + bc.pop() + + except_label = bc.pop(0).target + + # Remove the try blocks + i = 0 + last_begin = None + while i < len(bc): + instr = bc[i] + if isinstance(instr, bytecode.TryBegin) and instr.target is except_label: + last_begin = bc.pop(i) + elif isinstance(instr, bytecode.TryEnd) and last_begin is not None and instr.entry is last_begin: + bc.pop(i) + last_begin = None + else: + i += 1 + + # Remove the head of the try block + wc = wrapped.__dd_context_wrapped__ + for i, instr in enumerate(bc): + try: + if instr.name == "LOAD_CONST" and instr.arg is wc: + break + except AttributeError: + # Not an instruction + pass + + # Search for the RESUME instruction + for i, instr in enumerate(bc, 1): + try: + if instr.name == "RESUME": + break + except AttributeError: + # Not an instruction + pass + else: + i = 0 + + bc[i : i + len(CONTEXT_HEAD)] = [] + + # Un-prefix every return + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + return_code = CONTEXT_RETURN + elif sys.version_info >= (3, 12) and instr.name == "RETURN_CONST": # Python 3.12+ + return_code = CONTEXT_RETURN_CONST + else: + return_code = None + + if return_code is not None: + bc[i - len(return_code) : i] = [] + i -= len(return_code) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Recreate the code object + f.__code__ = bc.to_code() + + # Remove the wrapping context marker + del wrapped.__dd_context_wrapped__ + + else: + + def wrap(self) -> None: + f = self.__wrapped__ + + if self.is_wrapped(f): + raise ValueError("Function already wrapped") + + bc = Bytecode.from_code(f.__code__) + + # Prefix every return + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + return_code = CONTEXT_RETURN.bind({"context": self}, lineno=instr.lineno) + else: + return_code = [] + + bc[i:i] = return_code + i += len(return_code) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Search for the GEN_START instruction + i = 0 + if sys.version_info >= (3, 10) and iscoroutinefunction(f): + for i, instr in enumerate(bc, 1): + try: + if instr.name == "GEN_START": + break + except AttributeError: + # Not an instruction + pass + + *bc[i:i], except_label = CONTEXT_HEAD.bind({"context": self}, lineno=f.__code__.co_firstlineno) + + bc.append(except_label) + bc.extend(CONTEXT_FOOT.bind()) + + # Mark the function as wrapped by a wrapping context + t.cast(ContextWrappedFunction, f).__dd_context_wrapped__ = self + + # Replace the function code with the wrapped code + f.__code__ = bc.to_code() + + def unwrap(self) -> None: + f = self.__wrapped__ + + if not self.is_wrapped(f): + return + + wrapped = t.cast(ContextWrappedFunction, f) + + bc = Bytecode.from_code(f.__code__) + + # Remove the exception handling code + bc[-len(CONTEXT_FOOT) :] = [] + bc.pop() + + # Remove the head of the try block + wc = wrapped.__dd_context_wrapped__ + for i, instr in enumerate(bc): + try: + if instr.name == "LOAD_CONST" and instr.arg is wc: + break + except AttributeError: + # Not an instruction + pass + + bc[i : i + len(CONTEXT_HEAD) - 1] = [] + + # Remove all the return handlers + i = 0 + while i < len(bc): + instr = bc[i] + try: + if instr.name == "RETURN_VALUE": + bc[i - len(CONTEXT_RETURN) : i] = [] + i -= len(CONTEXT_RETURN) + except AttributeError: + # Not an instruction + pass + i += 1 + + # Recreate the code object + f.__code__ = bc.to_code() + + # Remove the wrapping context marker + del wrapped.__dd_context_wrapped__ diff --git a/tests/internal/test_wrapping.py b/tests/internal/test_wrapping.py index 3f7a73b19ef..5f9b8e77720 100644 --- a/tests/internal/test_wrapping.py +++ b/tests/internal/test_wrapping.py @@ -7,6 +7,8 @@ from ddtrace.internal.wrapping import unwrap from ddtrace.internal.wrapping import wrap +from ddtrace.internal.wrapping.context import WrappingContext +from ddtrace.internal.wrapping.context import _UniversalWrappingContext def assert_stack(expected): @@ -491,3 +493,198 @@ def f(a, b, c=None): assert closure(1, 2, 3) == (1, 2, 3, 42) assert channel == [((42,), {}), closure, ((1, 2, 3), {}), (1, 2, 3, 42)] + + +NOTSET = object() + + +class DummyWrappingContext(WrappingContext): + def __init__(self, f): + super().__init__(f) + + self.entered = False + self.exited = False + self.return_value = NOTSET + self.exc_info = None + self.frame = None + + def __enter__(self, frame): + self.entered = True + self.frame = frame + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.exited = True + if exc_value is not None: + self.exc_info = (exc_type, exc_value, traceback) + + def __return__(self, value): + self.return_value = value + return value + + +def test_wrapping_context_happy(): + def foo(): + return 42 + + wc = DummyWrappingContext(foo) + wc.wrap() + + assert foo() == 42 + + assert wc.entered + assert wc.return_value == 42 + assert not wc.exited + assert wc.exc_info is None + + assert wc.frame.f_code.co_name == "foo" + assert wc.frame.f_code.co_filename == __file__ + + +def test_wrapping_context_unwrapping(): + def foo(): + return 42 + + wc = DummyWrappingContext(foo) + wc.wrap() + assert _UniversalWrappingContext.is_wrapped(foo) + + wc.unwrap() + assert not _UniversalWrappingContext.is_wrapped(foo) + + assert foo() == 42 + + assert not wc.entered + assert wc.return_value is NOTSET + assert not wc.exited + assert wc.exc_info is None + + +def test_wrapping_context_exc(): + def foo(): + raise ValueError("foo") + + wc = DummyWrappingContext(foo) + wc.wrap() + + with pytest.raises(ValueError): + foo() + + assert wc.entered + assert wc.return_value is NOTSET + assert wc.exited + + _type, exc, _ = wc.exc_info + assert _type == ValueError + assert exc.args == ("foo",) + + +def test_wrapping_context_exc_on_exit(): + class BrokenExitWrappingContext(DummyWrappingContext): + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + raise RuntimeError("broken") + + def foo(): + raise ValueError("foo") + + wc = BrokenExitWrappingContext(foo) + wc.wrap() + + with pytest.raises(RuntimeError): + foo() + + assert wc.entered + assert wc.return_value is NOTSET + assert wc.exited + + _type, exc, _ = wc.exc_info + assert _type == ValueError + assert exc.args == ("foo",) + + +def test_wrapping_context_priority(): + class HighPriorityWrappingContext(DummyWrappingContext): + def __enter__(self, frame): + nonlocal mutated + + mutated = True + + return super().__enter__(frame) + + def __return__(self, value): + nonlocal mutated + + assert not mutated + + return super().__return__(value) + + class LowPriorityWrappingContext(DummyWrappingContext): + __priority__ = 99 + + def __enter__(self, frame): + nonlocal mutated + + assert mutated + + return super().__enter__(frame) + + def __return__(self, value): + nonlocal mutated + + mutated = False + + return super().__return__(value) + + mutated = False + + def foo(): + return 42 + + hwc = HighPriorityWrappingContext(foo) + lwc = LowPriorityWrappingContext(foo) + + # Wrap low first. We want to make sure that hwc is entered first + lwc.wrap() + hwc.wrap() + + foo() + + assert lwc.entered + assert hwc.return_value == 42 + + +@pytest.mark.asyncio +async def test_wrapping_context_async_happy() -> None: + async def coro(): + return 1 + + wc = DummyWrappingContext(coro) + wc.wrap() + + assert await coro() == 1 + + assert wc.entered + assert wc.return_value == 1 + assert not wc.exited + assert wc.exc_info is None + + +@pytest.mark.asyncio +async def test_wrapping_context_async_exc() -> None: + async def coro(): + raise ValueError("foo") + + wc = DummyWrappingContext(coro) + wc.wrap() + + with pytest.raises(ValueError): + await coro() + + assert wc.entered + assert wc.return_value is NOTSET + assert wc.exited + + _type, exc, _ = wc.exc_info + assert _type is ValueError + assert exc.args == ("foo",) From 591f34f595b71533c23fcac9c5f7f1964d780761 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Thu, 16 May 2024 14:24:09 +0100 Subject: [PATCH 2/6] add context storage --- ddtrace/internal/wrapping/context.py | 77 ++++++++++++++++++------- tests/internal/test_wrapping.py | 86 +++++++++++++++++++++++++--- 2 files changed, 134 insertions(+), 29 deletions(-) diff --git a/ddtrace/internal/wrapping/context.py b/ddtrace/internal/wrapping/context.py index 852da32dfe1..216485adf94 100644 --- a/ddtrace/internal/wrapping/context.py +++ b/ddtrace/internal/wrapping/context.py @@ -1,7 +1,9 @@ +from contextvars import ContextVar from inspect import iscoroutinefunction import sys from types import FrameType from types import FunctionType +from types import TracebackType import typing as t @@ -10,12 +12,6 @@ except ImportError: from typing_extensions import Protocol # type: ignore[assignment] -try: - from typing import override # type: ignore[attr-defined] -except ImportError: - from typing_extensions import override - - import bytecode from bytecode import Bytecode @@ -60,11 +56,13 @@ # class. The __priority__ attribute can be used to control the order in which # multiple context wrappers are entered and exited. The __enter__ and __exit__ # methods should be implemented to perform the necessary operations. The -# __enter__ method takes an extra argument representing the frame object of the -# wrapped function. The __return__ method can be implemented to capture the -# return value of the wrapped function. If implemented, its return value will be -# used as the wrapped function return value. The wrapped function can be -# accessed via the __wrapped__ attribute. +# __exit__ method is called if the wrapped function raises an exception. The +# frame of the wrapped function can be accessed via the __frame__ property. The +# __return__ method can be implemented to capture the return value of the +# wrapped function. If implemented, its return value will be used as the wrapped +# function return value. The wrapped function can be accessed via the +# __wrapped__ attribute. Context-specific values can be stored and retrieved +# with the set and get methods. CONTEXT_HEAD = Assembly() CONTEXT_RETURN = Assembly() @@ -260,10 +258,36 @@ # This is abstract and should not be used directly class BaseWrappingContext(t.ContextManager): __priority__: int = 0 - __frame__: t.Optional[FrameType] = None def __init__(self, f: FunctionType): self.__wrapped__ = f + self._storage_stack: ContextVar[list[dict]] = ContextVar(f"{type(self).__name__}__storage_stack", default=[]) + + def __enter__(self) -> "BaseWrappingContext": + self._storage_stack.get().append({}) + return self + + def _pop_storage(self) -> t.Dict[str, t.Any]: + return self._storage_stack.get().pop() + + def __return__(self, value: T) -> T: + self._pop_storage() + return value + + def __exit__( + self, + exc_type: t.Optional[t.Type[BaseException]], + exc_val: t.Optional[BaseException], + exc_tb: t.Optional[TracebackType], + ) -> None: + self._pop_storage() + + def get(self, key: str) -> t.Any: + return self._storage_stack.get()[-1][key] + + def set(self, key: str, value: T) -> T: + self._storage_stack.get()[-1][key] = value + return value @classmethod def wrapped(cls, f: FunctionType) -> "BaseWrappingContext": @@ -275,9 +299,6 @@ def wrapped(cls, f: FunctionType) -> "BaseWrappingContext": context.wrap() return context - def __return__(self, value): - return value - @classmethod def is_wrapped(cls, _f: FunctionType) -> bool: raise NotImplementedError @@ -295,9 +316,15 @@ def unwrap(self) -> None: # This is the public interface exported by this module class WrappingContext(BaseWrappingContext): - @override - def __enter__(self, _frame: FrameType) -> "WrappingContext": - raise NotImplementedError + @property + def __frame__(self) -> FrameType: + try: + return _UniversalWrappingContext.extract(self.__wrapped__).get("__frame__") + except ValueError: + raise AttributeError("Wrapping context not entered") + + def get_local(self, name: str) -> t.Any: + return self.__frame__.f_locals[name] @classmethod def is_wrapped(cls, f: FunctionType) -> bool: @@ -367,9 +394,14 @@ def registered(self, context_type: t.Type[WrappingContext]) -> WrappingContext: return self._contexts[context_type] def __enter__(self) -> "_UniversalWrappingContext": - frame = sys._getframe(1) + super().__enter__() + + # Make the frame object available to the contexts + self.set("__frame__", sys._getframe(1)) + for context in sorted(self._contexts.values(), key=lambda c: c.__priority__): - context.__enter__(frame) + context.__enter__() + return self def _exit(self) -> None: @@ -384,10 +416,13 @@ def __exit__(self, *exc) -> None: for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): context.__exit__(*exc) + super().__exit__(*exc) + def __return__(self, value: T) -> T: for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): context.__return__(value) - return value + + return super().__return__(value) @classmethod def is_wrapped(cls, f: FunctionType) -> bool: diff --git a/tests/internal/test_wrapping.py b/tests/internal/test_wrapping.py index 5f9b8e77720..3d7a7b563f8 100644 --- a/tests/internal/test_wrapping.py +++ b/tests/internal/test_wrapping.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import asynccontextmanager import inspect import sys @@ -508,19 +509,20 @@ def __init__(self, f): self.exc_info = None self.frame = None - def __enter__(self, frame): + def __enter__(self): self.entered = True - self.frame = frame - return self + self.frame = self.__frame__ + return super().__enter__() def __exit__(self, exc_type, exc_value, traceback): self.exited = True if exc_value is not None: self.exc_info = (exc_type, exc_value, traceback) + super().__exit__(exc_type, exc_value, traceback) def __return__(self, value): self.return_value = value - return value + return super().__return__(value) def test_wrapping_context_happy(): @@ -605,12 +607,12 @@ def foo(): def test_wrapping_context_priority(): class HighPriorityWrappingContext(DummyWrappingContext): - def __enter__(self, frame): + def __enter__(self): nonlocal mutated mutated = True - return super().__enter__(frame) + return super().__enter__() def __return__(self, value): nonlocal mutated @@ -622,12 +624,12 @@ def __return__(self, value): class LowPriorityWrappingContext(DummyWrappingContext): __priority__ = 99 - def __enter__(self, frame): + def __enter__(self): nonlocal mutated assert mutated - return super().__enter__(frame) + return super().__enter__() def __return__(self, value): nonlocal mutated @@ -654,6 +656,40 @@ def foo(): assert hwc.return_value == 42 +def test_wrapping_context_recursive(): + values = [] + + class RecursiveWrappingContext(DummyWrappingContext): + def __enter__(self): + nonlocal values + super().__enter__() + + n = self.__frame__.f_locals["n"] + self.set("n", n) + values.append(n) + + return self + + def __return__(self, value): + nonlocal values + n = self.__frame__.f_locals["n"] + assert self.get("n") == n + values.append(n) + + return super().__return__(value) + + def factorial(n): + if n == 0: + return 1 + return n * factorial(n - 1) + + wc = RecursiveWrappingContext(factorial) + wc.wrap() + + assert factorial(5) == 120 + assert values == [5, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 5] + + @pytest.mark.asyncio async def test_wrapping_context_async_happy() -> None: async def coro(): @@ -688,3 +724,37 @@ async def coro(): _type, exc, _ = wc.exc_info assert _type is ValueError assert exc.args == ("foo",) + + +@pytest.mark.asyncio +async def test_wrapping_context_async_concurrent() -> None: + values = [] + + class ConcurrentWrappingContext(DummyWrappingContext): + def __enter__(self): + super().__enter__() + + self.set("n", self.__frame__.f_locals["n"]) + + return self + + def __return__(self, value): + nonlocal values + + values.append((self.get("n"), self.__frame__.f_locals["n"])) + + return super().__return__(value) + + async def fibonacci(n): + if n <= 1: + return 1 + return sum(await asyncio.gather(fibonacci(n - 1), fibonacci(n - 2))) + + wc = ConcurrentWrappingContext(fibonacci) + wc.wrap() + + N = 20 + + await asyncio.gather(*[fibonacci(n) for n in range(1, N)]) + + assert set(values) == {(n, n) for n in range(0, N)} From 7ebe136b7ecc16479b1c812f5ddf2976c540051e Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Fri, 17 May 2024 15:50:20 +0100 Subject: [PATCH 3/6] fix 3.11+ unwrapping --- ddtrace/internal/wrapping/context.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/ddtrace/internal/wrapping/context.py b/ddtrace/internal/wrapping/context.py index 216485adf94..97d5955976b 100644 --- a/ddtrace/internal/wrapping/context.py +++ b/ddtrace/internal/wrapping/context.py @@ -531,14 +531,12 @@ def unwrap(self) -> None: # Remove the try blocks i = 0 - last_begin = None while i < len(bc): instr = bc[i] if isinstance(instr, bytecode.TryBegin) and instr.target is except_label: - last_begin = bc.pop(i) - elif isinstance(instr, bytecode.TryEnd) and last_begin is not None and instr.entry is last_begin: bc.pop(i) - last_begin = None + elif isinstance(instr, bytecode.TryEnd) and instr.entry.target is except_label: + bc.pop(i) else: i += 1 From 870576ae65ec7ba3dddcfc37b3ecd4546ff155d8 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Mon, 20 May 2024 12:27:13 +0100 Subject: [PATCH 4/6] move sorting to registration --- ddtrace/internal/wrapping/context.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/ddtrace/internal/wrapping/context.py b/ddtrace/internal/wrapping/context.py index 97d5955976b..24fd91d483e 100644 --- a/ddtrace/internal/wrapping/context.py +++ b/ddtrace/internal/wrapping/context.py @@ -368,22 +368,22 @@ class _UniversalWrappingContext(BaseWrappingContext): def __init__(self, f: FunctionType) -> None: super().__init__(f) - self._contexts: t.Dict[t.Type[WrappingContext], WrappingContext] = {} + self._contexts: t.List[WrappingContext] = [] def register(self, context: WrappingContext) -> None: _type = type(context) - if _type in self._contexts: + if any(isinstance(c, _type) for c in self._contexts): raise ValueError("Context already registered") - self._contexts[_type] = context + self._contexts.append(context) + self._contexts.sort(key=lambda c: c.__priority__) def unregister(self, context: WrappingContext) -> None: - _type = type(context) - if _type not in self._contexts: + try: + self._contexts.remove(context) + except ValueError: raise ValueError("Context not registered") - del self._contexts[_type] - if not self._contexts: self.unwrap() @@ -391,7 +391,10 @@ def is_registered(self, context: WrappingContext) -> bool: return type(context) in self._contexts def registered(self, context_type: t.Type[WrappingContext]) -> WrappingContext: - return self._contexts[context_type] + for context in self._contexts: + if isinstance(context, context_type): + return context + raise KeyError(f"Context {context_type} not registered") def __enter__(self) -> "_UniversalWrappingContext": super().__enter__() @@ -399,7 +402,7 @@ def __enter__(self) -> "_UniversalWrappingContext": # Make the frame object available to the contexts self.set("__frame__", sys._getframe(1)) - for context in sorted(self._contexts.values(), key=lambda c: c.__priority__): + for context in self._contexts: context.__enter__() return self @@ -413,13 +416,13 @@ def __exit__(self, *exc) -> None: # normally return - for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): + for context in self._contexts[::-1]: context.__exit__(*exc) super().__exit__(*exc) def __return__(self, value: T) -> T: - for context in sorted(self._contexts.values(), key=lambda c: -c.__priority__): + for context in self._contexts[::-1]: context.__return__(value) return super().__return__(value) From ecaa8444fade92b0326afa628c909819abe671a6 Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Mon, 20 May 2024 12:31:24 +0100 Subject: [PATCH 5/6] add generator test case --- tests/internal/test_wrapping.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/internal/test_wrapping.py b/tests/internal/test_wrapping.py index 3d7a7b563f8..8601c6ccecd 100644 --- a/tests/internal/test_wrapping.py +++ b/tests/internal/test_wrapping.py @@ -690,6 +690,22 @@ def factorial(n): assert values == [5, 4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 5] +def test_wrapping_context_generator(): + def foo(): + yield from range(10) + return 42 + + wc = DummyWrappingContext(foo) + wc.wrap() + + assert list(foo()) == list(range(10)) + + assert wc.entered + assert wc.return_value == 42 + assert not wc.exited + assert wc.exc_info is None + + @pytest.mark.asyncio async def test_wrapping_context_async_happy() -> None: async def coro(): From f8e0cfafa6f8cafaa3881eccc9ca470ae64bde3e Mon Sep 17 00:00:00 2001 From: "Gabriele N. Tornetta" Date: Tue, 21 May 2024 10:46:27 +0100 Subject: [PATCH 6/6] add async generator test case --- tests/internal/test_wrapping.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/internal/test_wrapping.py b/tests/internal/test_wrapping.py index 8601c6ccecd..d27eadac43b 100644 --- a/tests/internal/test_wrapping.py +++ b/tests/internal/test_wrapping.py @@ -706,6 +706,28 @@ def foo(): assert wc.exc_info is None +@pytest.mark.asyncio +async def test_wrapping_context_async_generator(): + async def arange(count): + for i in range(count): + yield (i) + await asyncio.sleep(0.0) + + wc = DummyWrappingContext(arange) + wc.wrap() + + a = [] + async for _ in arange(10): + a.append(_) + + assert a == list(range(10)) + + assert wc.entered + assert wc.return_value is None + assert not wc.exited + assert wc.exc_info is None + + @pytest.mark.asyncio async def test_wrapping_context_async_happy() -> None: async def coro():