diff --git a/ddtrace/context.py b/ddtrace/context.py index 89079bc718b..16b30995ae7 100644 --- a/ddtrace/context.py +++ b/ddtrace/context.py @@ -20,7 +20,7 @@ class Context(object): This data structure is thread-safe. """ - def __init__(self, trace_id=None, span_id=None): + def __init__(self, trace_id=None, span_id=None, sampled=True): """ Initialize a new thread-safe ``Context``. @@ -28,7 +28,7 @@ def __init__(self, trace_id=None, span_id=None): :param int span_id: span_id of parent span """ self._trace = [] - self._sampled = False + self._sampled = sampled self._finished_spans = 0 self._current_span = None self._lock = threading.Lock() diff --git a/ddtrace/contrib/asyncio/helpers.py b/ddtrace/contrib/asyncio/helpers.py index a687426cf7e..b0842663fe5 100644 --- a/ddtrace/contrib/asyncio/helpers.py +++ b/ddtrace/contrib/asyncio/helpers.py @@ -79,8 +79,11 @@ def _wrap_executor(fn, args, tracer, ctx): def create_task(*args, **kwargs): - """ This method will enable spawned tasks to parent to the base task context """ - return _wrapped_create_task(_orig_create_task, None, args, kwargs) + """This function spawns a task with a Context that inherits the + `trace_id` and the `parent_id` from the current active one if available. + """ + loop = asyncio.get_event_loop() + return _wrapped_create_task(loop.create_task, None, args, kwargs) def _wrapped_create_task(wrapped, instance, args, kwargs): diff --git a/ddtrace/contrib/asyncio/patch.py b/ddtrace/contrib/asyncio/patch.py index 9642341370a..040b67733f4 100644 --- a/ddtrace/contrib/asyncio/patch.py +++ b/ddtrace/contrib/asyncio/patch.py @@ -1,39 +1,27 @@ -# project -import ddtrace -from ddtrace.util import unwrap -from ddtrace.provider import DefaultContextProvider - -# 3p -import wrapt import asyncio -from .helpers import _wrapped_create_task -from . import context_provider +from wrapt import wrap_function_wrapper as _w -_orig_create_task = asyncio.BaseEventLoop.create_task +from .helpers import _wrapped_create_task +from ...util import unwrap as _u -def patch(tracer=ddtrace.tracer): - """ - Patches `BaseEventLoop.create_task` to enable spawned tasks to parent to - the base task context. Will also enable the asyncio task context. +def patch(): + """Patches current loop `create_task()` method to enable spawned tasks to + parent to the base task context. """ - # TODO: figure what to do with helpers.ensure_future and - # helpers.run_in_executor (doesn't work for ProcessPoolExecutor) if getattr(asyncio, '_datadog_patch', False): return setattr(asyncio, '_datadog_patch', True) - tracer.configure(context_provider=context_provider) - wrapt.wrap_function_wrapper('asyncio', 'BaseEventLoop.create_task', _wrapped_create_task) + loop = asyncio.get_event_loop() + _w(loop, 'create_task', _wrapped_create_task) -def unpatch(tracer=ddtrace.tracer): - """ - Remove tracing from patched modules. - """ +def unpatch(): + """Remove tracing from patched modules.""" + if getattr(asyncio, '_datadog_patch', False): setattr(asyncio, '_datadog_patch', False) - - tracer.configure(context_provider=DefaultContextProvider()) - unwrap(asyncio.BaseEventLoop, 'create_task') + loop = asyncio.get_event_loop() + _u(loop, 'create_task') diff --git a/tests/contrib/asyncio/test_helpers.py b/tests/contrib/asyncio/test_helpers.py index 16d7b6feb4f..dc22943fa27 100644 --- a/tests/contrib/asyncio/test_helpers.py +++ b/tests/contrib/asyncio/test_helpers.py @@ -31,7 +31,7 @@ def future_work(): eq_('coroutine', ctx._trace[0].name) return ctx._trace[0].name - span = self.tracer.trace('coroutine') + self.tracer.trace('coroutine') # schedule future work and wait for a result delayed_task = helpers.ensure_future(future_work(), tracer=self.tracer) result = yield from asyncio.wait_for(delayed_task, timeout=1) @@ -67,3 +67,21 @@ def future_work(): span.finish() result = yield from future ok_(result) + + @mark_asyncio + def test_create_task(self): + # the helper should create a new Task that has the Context attached + @asyncio.coroutine + def future_work(): + # the ctx is available in this task + ctx = self.tracer.get_call_context() + eq_(0, len(ctx._trace)) + child_span = self.tracer.trace('child_task') + return child_span + + root_span = self.tracer.trace('main_task') + # schedule future work and wait for a result + task = helpers.create_task(future_work()) + result = yield from task + eq_(root_span.trace_id, result.trace_id) + eq_(root_span.span_id, result.parent_id) diff --git a/tests/contrib/asyncio/test_tracer.py b/tests/contrib/asyncio/test_tracer.py index dcc432304ba..ccce4fef805 100644 --- a/tests/contrib/asyncio/test_tracer.py +++ b/tests/contrib/asyncio/test_tracer.py @@ -1,23 +1,22 @@ import asyncio + from asyncio import BaseEventLoop from ddtrace.context import Context -from ddtrace.contrib.asyncio.helpers import set_call_context -from ddtrace.contrib.asyncio.patch import patch, unpatch -from ddtrace.contrib.asyncio import context_provider from ddtrace.provider import DefaultContextProvider +from ddtrace.contrib.asyncio.patch import patch, unpatch +from ddtrace.contrib.asyncio.helpers import set_call_context from nose.tools import eq_, ok_ - from .utils import AsyncioTestCase, mark_asyncio + _orig_create_task = BaseEventLoop.create_task class TestAsyncioTracer(AsyncioTestCase): - """ - Ensure that the ``AsyncioTracer`` works for asynchronous execution - within the same ``IOLoop``. + """Ensure that the tracer works with asynchronous executions within + the same ``IOLoop``. """ @mark_asyncio def test_get_call_context(self): @@ -204,92 +203,102 @@ def f1(): span = spans[0] ok_(span.duration > 0.25, msg='span.duration={}'.format(span.duration)) - @mark_asyncio - def test_patch_chain(self): - patch(self.tracer) - - assert self.tracer._context_provider is context_provider - - with self.tracer.trace('foo'): - @self.tracer.wrap('f1') - @asyncio.coroutine - def f1(): - yield from asyncio.sleep(0.1) - - @self.tracer.wrap('f2') - @asyncio.coroutine - def f2(): - yield from asyncio.ensure_future(f1()) - - yield from asyncio.ensure_future(f2()) - - traces = list(reversed(self.tracer.writer.pop_traces())) - assert len(traces) == 3 - root_span = traces[0][0] - last_span_id = None - for trace in traces: - assert len(trace) == 1 - span = trace[0] - assert span.trace_id == root_span.trace_id - assert span.parent_id == last_span_id - last_span_id = span.span_id + +class TestAsyncioPropagation(AsyncioTestCase): + """Ensure that asyncio context propagation works between different tasks""" + def setUp(self): + # patch asyncio event loop + super(TestAsyncioPropagation, self).setUp() + patch() + + def tearDown(self): + # unpatch asyncio event loop + super(TestAsyncioPropagation, self).tearDown() + unpatch() @mark_asyncio - def test_patch_parallel(self): - patch(self.tracer) + def test_tasks_chaining(self): + # ensures that the context is propagated between different tasks + @self.tracer.wrap('spawn_task') + @asyncio.coroutine + def coro_2(): + yield from asyncio.sleep(0.01) + + @self.tracer.wrap('main_task') + @asyncio.coroutine + def coro_1(): + yield from asyncio.ensure_future(coro_2()) - assert self.tracer._context_provider is context_provider + yield from coro_1() - with self.tracer.trace('foo'): - @self.tracer.wrap('f1') - @asyncio.coroutine - def f1(): - yield from asyncio.sleep(0.1) + traces = self.tracer.writer.pop_traces() + eq_(len(traces), 2) + eq_(len(traces[0]), 1) + eq_(len(traces[1]), 1) + spawn_task = traces[0][0] + main_task = traces[1][0] + # check if the context has been correctly propagated + eq_(spawn_task.trace_id, main_task.trace_id) + eq_(spawn_task.parent_id, main_task.span_id) - @self.tracer.wrap('f2') - @asyncio.coroutine - def f2(): - yield from asyncio.sleep(0.1) + @mark_asyncio + def test_concurrent_chaining(self): + # ensures that the context is correctly propagated when + # concurrent tasks are created from a common tracing block + @self.tracer.wrap('f1') + @asyncio.coroutine + def f1(): + yield from asyncio.sleep(0.01) + @self.tracer.wrap('f2') + @asyncio.coroutine + def f2(): + yield from asyncio.sleep(0.01) + + with self.tracer.trace('main_task'): yield from asyncio.gather(f1(), f2()) traces = self.tracer.writer.pop_traces() - assert len(traces) == 3 - root_span = traces[2][0] - for trace in traces[:2]: - assert len(trace) == 1 - span = trace[0] - assert span.trace_id == root_span.trace_id - assert span.parent_id == root_span.span_id + eq_(len(traces), 3) + eq_(len(traces[0]), 1) + eq_(len(traces[1]), 1) + eq_(len(traces[2]), 1) + child_1 = traces[0][0] + child_2 = traces[1][0] + main_task = traces[2][0] + # check if the context has been correctly propagated + eq_(child_1.trace_id, main_task.trace_id) + eq_(child_1.parent_id, main_task.span_id) + eq_(child_2.trace_id, main_task.trace_id) + eq_(child_2.parent_id, main_task.span_id) @mark_asyncio - def test_distributed(self): - patch(self.tracer) - + def test_propagation_with_new_context(self): + # ensures that if a new Context is attached to the current + # running Task, a previous trace is resumed task = asyncio.Task.current_task() ctx = Context(trace_id=100, span_id=101) set_call_context(task, ctx) - with self.tracer.trace('foo'): - pass + with self.tracer.trace('async_task'): + yield from asyncio.sleep(0.01) traces = self.tracer.writer.pop_traces() - assert len(traces) == 1 - trace = traces[0] - assert len(trace) == 1 - span = trace[0] - - assert span.trace_id == ctx._parent_trace_id - assert span.parent_id == ctx._parent_span_id + eq_(len(traces), 1) + eq_(len(traces[0]), 1) + span = traces[0][0] + eq_(span.trace_id, 100) + eq_(span.parent_id, 101) @mark_asyncio - def test_unpatch(self): - patch(self.tracer) - unpatch(self.tracer) - - assert isinstance(self.tracer._context_provider, DefaultContextProvider) - assert BaseEventLoop.create_task == _orig_create_task - - def test_double_patch(self): - patch(self.tracer) - self.test_patch_chain() + def test_event_loop_unpatch(self): + # ensures that the event loop can be unpatched + unpatch() + ok_(isinstance(self.tracer._context_provider, DefaultContextProvider)) + ok_(BaseEventLoop.create_task == _orig_create_task) + + def test_event_loop_double_patch(self): + # ensures that double patching will not double instrument + # the event loop + patch() + self.test_tasks_chaining()