diff --git a/ddtrace/profiling/_nogevent.py b/ddtrace/profiling/_nogevent.py new file mode 100644 index 00000000000..888b8ed845c --- /dev/null +++ b/ddtrace/profiling/_nogevent.py @@ -0,0 +1,67 @@ +# -*- encoding: utf-8 -*- +"""This file exposes non-gevent Python original functions.""" +import threading + +from ddtrace.vendor import six +from ddtrace.vendor import attr + + +try: + import gevent.monkey +except ImportError: + + def get_original(module, func): + return getattr(__import__(module), func) + + def is_module_patched(module): + return False + + +else: + get_original = gevent.monkey.get_original + is_module_patched = gevent.monkey.is_module_patched + + +sleep = get_original("time", "sleep") + +try: + # Python ≥ 3.8 + threading_get_native_id = get_original("threading", "get_native_id") +except AttributeError: + threading_get_native_id = None + +start_new_thread = get_original(six.moves._thread.__name__, "start_new_thread") +thread_get_ident = get_original(six.moves._thread.__name__, "get_ident") +Thread = get_original("threading", "Thread") +Lock = get_original("threading", "Lock") + + +if is_module_patched("threading"): + + @attr.s + class DoubleLock(object): + """A lock that prevents concurrency from a gevent coroutine and from a threading.Thread at the same time.""" + + _lock = attr.ib(factory=threading.Lock, init=False, repr=False) + _thread_lock = attr.ib(factory=Lock, init=False, repr=False) + + def acquire(self): + # You cannot acquire a gevent-lock from another thread if it has been acquired already: + # make sure we exclude the gevent-lock from being acquired by another thread by using a thread-lock first. 
+ self._thread_lock.acquire() + self._lock.acquire() + + def release(self): + self._lock.release() + self._thread_lock.release() + + def __enter__(self): + self.acquire() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.release() + + +else: + DoubleLock = threading.Lock diff --git a/ddtrace/profiling/_periodic.py b/ddtrace/profiling/_periodic.py index e1fece5be3d..3c99267ee17 100644 --- a/ddtrace/profiling/_periodic.py +++ b/ddtrace/profiling/_periodic.py @@ -3,8 +3,8 @@ import threading from ddtrace.profiling import _service +from ddtrace.profiling import _nogevent from ddtrace.vendor import attr -from ddtrace.vendor import six PERIODIC_THREAD_IDS = set() @@ -70,14 +70,6 @@ def __init__(self, interval, target, name=None, on_shutdown=None): :param on_shutdown: The function to call when the thread shuts down. """ super(_GeventPeriodicThread, self).__init__(interval, target, name, on_shutdown) - import gevent.monkey - - self._sleep = gevent.monkey.get_original("time", "sleep") - try: - # Python ≥ 3.8 - self._get_native_id = gevent.monkey.get_original("threading", "get_native_id") - except AttributeError: - self._get_native_id = None self._tident = None @property @@ -86,24 +78,20 @@ def ident(self): def start(self): """Start the thread.""" - import gevent.monkey - - start_new_thread = gevent.monkey.get_original(six.moves._thread.__name__, "start_new_thread") - self.quit = False self.has_quit = False threading._limbo[self] = self try: - self._tident = start_new_thread(self.run, tuple()) + self._tident = _nogevent.start_new_thread(self.run, tuple()) except Exception: del threading._limbo[self] - if self._get_native_id: - self._native_id = self._get_native_id() + if _nogevent.threading_get_native_id: + self._native_id = _nogevent.threading_get_native_id() def join(self, timeout=None): # FIXME: handle the timeout argument while not self.has_quit: - self._sleep(self.SLEEP_INTERVAL) + _nogevent.sleep(self.SLEEP_INTERVAL) def stop(self): """Stop the 
thread.""" @@ -121,7 +109,7 @@ def run(self): self._target() slept = 0 while self.quit is False and slept < self.interval: - self._sleep(self.SLEEP_INTERVAL) + _nogevent.sleep(self.SLEEP_INTERVAL) slept += self.SLEEP_INTERVAL if self._on_shutdown is not None: self._on_shutdown() @@ -151,11 +139,8 @@ def PeriodicRealThread(*args, **kwargs): in e.g. the gevent case, where Lock object must not be shared with the MainThread (otherwise it'd dead lock). """ - if "gevent" in sys.modules: - import gevent.monkey - - if gevent.monkey.is_module_patched("threading"): - return _GeventPeriodicThread(*args, **kwargs) + if _nogevent.is_module_patched("threading"): + return _GeventPeriodicThread(*args, **kwargs) return PeriodicThread(*args, **kwargs) diff --git a/ddtrace/profiling/_service.py b/ddtrace/profiling/_service.py index 1a0173598a7..36e8f5838aa 100644 --- a/ddtrace/profiling/_service.py +++ b/ddtrace/profiling/_service.py @@ -27,7 +27,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - return self.stop() + self.stop() + self.join() def start(self): """Start the service.""" diff --git a/ddtrace/profiling/collector/stack.pyx b/ddtrace/profiling/collector/stack.pyx index 0da3a8d3fda..8d77ad55f69 100644 --- a/ddtrace/profiling/collector/stack.pyx +++ b/ddtrace/profiling/collector/stack.pyx @@ -10,6 +10,7 @@ import weakref from ddtrace import compat from ddtrace.profiling import _attr from ddtrace.profiling import _periodic +from ddtrace.profiling import _nogevent from ddtrace.profiling import collector from ddtrace.profiling import event from ddtrace.profiling.collector import _traceback @@ -21,22 +22,12 @@ from ddtrace.vendor import six _LOG = logging.getLogger(__name__) -if "gevent" in sys.modules: - try: - import gevent.monkey - except ImportError: - _LOG.error("gevent loaded but unable to import gevent.monkey") - from threading import Lock as _threading_Lock - from ddtrace.vendor.six.moves._thread import get_ident as 
_thread_get_ident - else: - _threading_Lock = gevent.monkey.get_original("threading", "Lock") - _thread_get_ident = gevent.monkey.get_original("thread" if six.PY2 else "_thread", "get_ident") - +if _nogevent.is_module_patched("threading"): # NOTE: bold assumption: this module is always imported by the MainThread. # The python `threading` module makes that assumption and it's beautiful we're going to do the same. - _main_thread_id = _thread_get_ident() + # We don't have a choice as we can't access the original MainThread + _main_thread_id = _nogevent.thread_get_ident() else: - from threading import Lock as _threading_Lock from ddtrace.vendor.six.moves._thread import get_ident as _thread_get_ident if six.PY2: _main_thread_id = threading._MainThread().ident @@ -388,7 +379,7 @@ class _ThreadSpanLinks(object): # Keys is a thread_id # Value is a set of weakrefs to spans _thread_id_to_spans = attr.ib(factory=lambda: collections.defaultdict(set), repr=False, init=False) - _lock = attr.ib(factory=_threading_Lock, repr=False, init=False) + _lock = attr.ib(factory=_nogevent.Lock, repr=False, init=False) def link_span(self, span): """Link a span to its running environment. @@ -397,7 +388,7 @@ class _ThreadSpanLinks(object): """ # Since we're going to iterate over the set, make sure it's locked with self._lock: - self._thread_id_to_spans[_thread_get_ident()].add(weakref.ref(span)) + self._thread_id_to_spans[_nogevent.thread_get_ident()].add(weakref.ref(span)) def clear_threads(self, existing_thread_ids): """Clear the stored list of threads based on the list of existing thread ids. 
diff --git a/ddtrace/profiling/recorder.py b/ddtrace/profiling/recorder.py index c8bd96a1332..15b73147684 100644 --- a/ddtrace/profiling/recorder.py +++ b/ddtrace/profiling/recorder.py @@ -1,6 +1,7 @@ # -*- encoding: utf-8 -*- import collections +from ddtrace.profiling import _nogevent from ddtrace.vendor import attr @@ -30,6 +31,7 @@ class Recorder(object): """A dict of {event_type_class: max events} to limit the number of events to record.""" events = attr.ib(init=False, repr=False) + _events_lock = attr.ib(init=False, repr=False, factory=_nogevent.DoubleLock) def __attrs_post_init__(self): self._reset_events() @@ -51,8 +53,9 @@ def push_events(self, events): """ if events: event_type = events[0].__class__ - q = self.events[event_type] - q.extend(events) + with self._events_lock: + q = self.events[event_type] + q.extend(events) def _get_deque_for_event_type(self, event_type): return collections.deque(maxlen=self.max_events.get(event_type, self.default_max_events)) @@ -68,6 +71,7 @@ def reset(self): :return: The list of events that has been removed. 
""" - events = self.events - self._reset_events() + with self._events_lock: + events = self.events + self._reset_events() return events diff --git a/tests/profiling/collector/test_stack.py b/tests/profiling/collector/test_stack.py index ccb88cc491c..5fe029b971e 100644 --- a/tests/profiling/collector/test_stack.py +++ b/tests/profiling/collector/test_stack.py @@ -9,6 +9,7 @@ import ddtrace from ddtrace.vendor import six +from ddtrace.profiling import _nogevent from ddtrace.profiling import recorder from ddtrace.profiling.collector import stack @@ -16,17 +17,6 @@ TESTING_GEVENT = os.getenv("DD_PROFILE_TEST_GEVENT", False) -try: - from gevent import monkey -except ImportError: - real_sleep = time.sleep - real_Thread = threading.Thread - real_Lock = threading.Lock -else: - real_sleep = monkey.get_original("time", "sleep") - real_Thread = monkey.get_original("threading", "Thread") - real_Lock = monkey.get_original("threading", "Lock") - def func1(): return func2() @@ -45,7 +35,7 @@ def func4(): def func5(): - return real_sleep(1) + return _nogevent.sleep(1) def test_collect_truncate(): @@ -192,6 +182,25 @@ def test_stress_threads(): t.join() +def test_stress_threads_run_as_thread(): + NB_THREADS = 40 + + threads = [] + for i in range(NB_THREADS): + t = threading.Thread(target=_f0) # noqa: E149,F821 + t.start() + threads.append(t) + + r = recorder.Recorder() + s = stack.StackCollector(recorder=r) + # This mainly checks that nothing bad happens when we collect a lot of threads and store the result in the Recorder + with s: + time.sleep(3) + assert r.events[stack.StackSampleEvent] + for t in threads: + t.join() + + @pytest.mark.skipif(not stack.FEATURES["stack-exceptions"], reason="Stack exceptions not supported") @pytest.mark.skipif(TESTING_GEVENT, reason="Test not compatible with gevent") def test_exception_collection_threads(): @@ -221,21 +230,20 @@ def test_exception_collection_threads(): def test_exception_collection(): r = recorder.Recorder() c = stack.StackCollector(r) - 
c.start() - try: - raise ValueError("hello") - except Exception: - real_sleep(1) - c.stop() + with c: + try: + raise ValueError("hello") + except Exception: + _nogevent.sleep(1) exception_events = r.events[stack.StackExceptionSampleEvent] assert len(exception_events) >= 1 e = exception_events[0] assert e.timestamp > 0 assert e.sampling_period > 0 - assert e.thread_id == stack._thread_get_ident() + assert e.thread_id == _nogevent.thread_get_ident() assert e.thread_name == "MainThread" - assert e.frames == [(__file__, 228, "test_exception_collection")] + assert e.frames == [(__file__, 237, "test_exception_collection")] assert e.nframes == 1 assert e.exc_type == ValueError @@ -255,7 +263,7 @@ def tracer_and_collector(): def test_thread_to_span_thread_isolation(tracer_and_collector): t, c = tracer_and_collector root = t.start_span("root") - thread_id = stack._thread_get_ident() + thread_id = _nogevent.thread_get_ident() assert c._thread_span_links.get_active_leaf_spans_from_thread_id(thread_id) == {root} store = {} @@ -278,7 +286,7 @@ def start_span(): def test_thread_to_span_multiple(tracer_and_collector): t, c = tracer_and_collector root = t.start_span("root") - thread_id = stack._thread_get_ident() + thread_id = _nogevent.thread_get_ident() assert c._thread_span_links.get_active_leaf_spans_from_thread_id(thread_id) == {root} subspan = t.start_span("subtrace", child_of=root) assert c._thread_span_links.get_active_leaf_spans_from_thread_id(thread_id) == {subspan} @@ -297,7 +305,7 @@ def test_thread_to_child_span_multiple_unknown_thread(tracer_and_collector): def test_thread_to_child_span_clear(tracer_and_collector): t, c = tracer_and_collector root = t.start_span("root") - thread_id = stack._thread_get_ident() + thread_id = _nogevent.thread_get_ident() assert c._thread_span_links.get_active_leaf_spans_from_thread_id(thread_id) == {root} c._thread_span_links.clear_threads(set()) assert c._thread_span_links.get_active_leaf_spans_from_thread_id(thread_id) == set() @@ 
-306,7 +314,7 @@ def test_thread_to_child_span_clear(tracer_and_collector): def test_thread_to_child_span_multiple_more_children(tracer_and_collector): t, c = tracer_and_collector root = t.start_span("root") - thread_id = stack._thread_get_ident() + thread_id = _nogevent.thread_get_ident() assert c._thread_span_links.get_active_leaf_spans_from_thread_id(thread_id) == {root} subspan = t.start_span("subtrace", child_of=root) subsubspan = t.start_span("subsubtrace", child_of=subspan) @@ -374,10 +382,10 @@ def _trace(): def test_thread_time_cache(): tt = stack._ThreadTime() - lock = real_Lock() + lock = _nogevent.Lock() lock.acquire() - t = real_Thread(target=lock.acquire) + t = _nogevent.Thread(target=lock.acquire) t.start() main_thread_id = threading.current_thread().ident