diff --git a/ddtrace/internal/wrapping/context.py b/ddtrace/internal/wrapping/context.py index 2e5be4b1013..2389631106a 100644 --- a/ddtrace/internal/wrapping/context.py +++ b/ddtrace/internal/wrapping/context.py @@ -13,7 +13,13 @@ from bytecode import Bytecode from ddtrace.internal.assembly import Assembly +from ddtrace.internal.forksafe import Lock from ddtrace.internal.utils.inspection import link_function_to_code +from ddtrace.internal.wrapping import WrappedFunction +from ddtrace.internal.wrapping import Wrapper +from ddtrace.internal.wrapping import is_wrapped_with +from ddtrace.internal.wrapping import unwrap +from ddtrace.internal.wrapping import wrap T = t.TypeVar("T") @@ -406,6 +412,44 @@ def unwrap(self) -> None: _UniversalWrappingContext.extract(f).unregister(self) +class LazyWrappingContext(WrappingContext): + def __init__(self, f: FunctionType): + super().__init__(f) + + self._trampoline: t.Optional[Wrapper] = None + self._trampoline_lock = Lock() + + def wrap(self) -> None: + """Perform the bytecode wrapping on first invocation.""" + with (tl := self._trampoline_lock): + if self._trampoline is not None: + return + + def trampoline(_, args, kwargs): + with tl: + f = t.cast(WrappedFunction, self.__wrapped__) + if is_wrapped_with(self.__wrapped__, trampoline): + f = unwrap(f, trampoline) + super(LazyWrappingContext, self).wrap() + return f(*args, **kwargs) + + wrap(self.__wrapped__, trampoline) + + self._trampoline = trampoline + + def unwrap(self) -> None: + with self._trampoline_lock: + if self._trampoline is None: + return + + if self.is_wrapped(self.__wrapped__): + super().unwrap() + else: + unwrap(t.cast(WrappedFunction, self.__wrapped__), self._trampoline) + + self._trampoline = None + + class ContextWrappedFunction(Protocol): """A wrapped function.""" diff --git a/tests/internal/test_wrapping.py b/tests/internal/test_wrapping.py index 3610f0d452a..40e048f5bf3 100644 --- a/tests/internal/test_wrapping.py +++ b/tests/internal/test_wrapping.py @@ -10,6 +10,7 @@ from ddtrace.internal.wrapping import is_wrapped_with from ddtrace.internal.wrapping import unwrap from ddtrace.internal.wrapping import wrap +from ddtrace.internal.wrapping.context import LazyWrappingContext from ddtrace.internal.wrapping.context import WrappingContext from ddtrace.internal.wrapping.context import _UniversalWrappingContext @@ -926,3 +927,42 @@ def foo(): new_method_count = len([_ for _ in gc.get_objects() if type(_).__name__ == "method"]) assert new_method_count <= method_count + 1 + + +def test_wrapping_context_lazy(): + free = 42 + + def foo(): + return free + + class DummyLazyWrappingContext(LazyWrappingContext): + def __init__(self, f): + super().__init__(f) + + self.count = 0 + + def __enter__(self): + self.count += 1 + return super().__enter__() + + (wc := DummyLazyWrappingContext(foo)).wrap() + + assert not DummyLazyWrappingContext.is_wrapped(foo) + + for _ in range(n := 10): + assert foo() == free + + assert DummyLazyWrappingContext.is_wrapped(foo) + + assert wc.count == n + + wc.count = 0 + + wc.unwrap() + + for _ in range(10): + assert not DummyLazyWrappingContext.is_wrapped(foo) + + assert foo() == free + + assert wc.count == 0