diff --git a/sdk/core/azure-core/azure/core/tracing/decorator.py b/sdk/core/azure-core/azure/core/tracing/decorator.py index 686dd9d7b7628..7330af62ee31c 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator.py @@ -27,6 +27,8 @@ import functools +from typing import overload + from .common import change_context, get_function_and_class_name from ..settings import settings @@ -36,11 +38,27 @@ TYPE_CHECKING = False if TYPE_CHECKING: - from typing import Callable, Dict, Optional, Any, cast + from typing import Callable, Dict, Optional, Any, TypeVar + + T = TypeVar("T") + + +@overload +def distributed_trace(__func): + # type: (Callable[..., T]) -> Callable[..., T] + pass -def distributed_trace(_func=None, name_of_span=None, **kwargs): - # type: (Callable, Optional[str], Optional[Dict[str, Any]]) -> Callable +@overload +def distributed_trace(**kwargs): # pylint:disable=function-redefined,unused-argument + # type: (**Any) -> Callable[[Callable[..., T]], Callable[..., T]] + pass + + +def distributed_trace( # pylint:disable=function-redefined + __func=None, # type: Callable[..., T] + **kwargs # type: Any +): """Decorator to apply to function to get traced automatically. Span will use the func name or "name_of_span". @@ -48,37 +66,33 @@ def distributed_trace(_func=None, name_of_span=None, **kwargs): :param callable func: A function to decorate :param str name_of_span: The span name to replace func name if necessary """ - tracing_attributes = kwargs.get('tracing_attributes') - # https://github.com/python/mypy/issues/2608 - if _func is None: - return functools.partial( - distributed_trace, - name_of_span=name_of_span, - tracing_attributes=tracing_attributes, - ) - func = _func # mypy is happy now - - not_none_tracing_attributes = tracing_attributes if tracing_attributes else {} - - @functools.wraps(func) - def wrapper_use_tracer(*args, **kwargs): - # type: (Any, Any) -> Any - merge_span = kwargs.pop("merge_span", False) - passed_in_parent = kwargs.pop("parent_span", None) - - span_impl_type = settings.tracing_implementation() - if span_impl_type is None: - return func(*args, **kwargs) - - # Merge span is parameter is set, but only if no explicit parent are passed - if merge_span and not passed_in_parent: - return func(*args, **kwargs) - - with change_context(passed_in_parent): - name = name_of_span or get_function_and_class_name(func, *args) - with span_impl_type(name=name) as span: - for key, value in not_none_tracing_attributes.items(): - span.add_attribute(key, value) + name_of_span = kwargs.pop("name_of_span", None) + tracing_attributes = kwargs.pop("tracing_attributes", {}) + + def decorator(func): + # type: (Callable[..., T]) -> Callable[..., T] + + @functools.wraps(func) + def wrapper_use_tracer(*args, **kwargs): + # type: (*Any, **Any) -> T + merge_span = kwargs.pop("merge_span", False) + passed_in_parent = kwargs.pop("parent_span", None) + + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return func(*args, **kwargs) + + # Merge span is parameter is set, but only if no explicit parent are passed + if merge_span and not passed_in_parent: return func(*args, **kwargs) - return wrapper_use_tracer + with change_context(passed_in_parent): + name = name_of_span or get_function_and_class_name(func, *args) + with span_impl_type(name=name) as span: + for key, value in tracing_attributes.items(): + span.add_attribute(key, value) + return func(*args, **kwargs) + + return wrapper_use_tracer + + return decorator if __func is None else decorator(__func) diff --git a/sdk/core/azure-core/azure/core/tracing/decorator_async.py b/sdk/core/azure-core/azure/core/tracing/decorator_async.py index 8a2a0ad6efcef..de051987274bb 100644 --- a/sdk/core/azure-core/azure/core/tracing/decorator_async.py +++ b/sdk/core/azure-core/azure/core/tracing/decorator_async.py @@ -27,57 +27,64 @@ import functools +from typing import Awaitable, Callable, Dict, Optional, Any, TypeVar, overload + from .common import change_context, get_function_and_class_name from ..settings import settings -try: - from typing import TYPE_CHECKING -except ImportError: - TYPE_CHECKING = False -if TYPE_CHECKING: - from typing import Callable, Dict, Optional, Any +T = TypeVar("T") + + +@overload +def distributed_trace_async( + __func: Callable[..., Awaitable[T]] +) -> Callable[..., Awaitable[T]]: + pass -def distributed_trace_async(_func=None, name_of_span=None, *, tracing_attributes=None): - # type: (Callable, Optional[str], Optional[Dict[str, Any]]) -> Callable - """Decorator to apply to async function to get traced automatically. +@overload +def distributed_trace_async( # pylint:disable=function-redefined + **kwargs: Any # pylint:disable=unused-argument +) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]: + pass + + +def distributed_trace_async( # pylint:disable=function-redefined + __func: Callable[..., Awaitable[T]] = None, **kwargs: Any +): + """Decorator to apply to function to get traced automatically. Span will use the func name or "name_of_span". :param callable func: A function to decorate :param str name_of_span: The span name to replace func name if necessary """ - # https://github.com/python/mypy/issues/2608 - if _func is None: - return functools.partial( - distributed_trace_async, - name_of_span=name_of_span, - tracing_attributes=tracing_attributes, - ) - func = _func # mypy is happy now - - not_none_tracing_attributes = tracing_attributes if tracing_attributes else {} - - @functools.wraps(func) - async def wrapper_use_tracer(*args, **kwargs): - # type: (Any, Any) -> Any - merge_span = kwargs.pop("merge_span", False) - passed_in_parent = kwargs.pop("parent_span", None) - - span_impl_type = settings.tracing_implementation() - if span_impl_type is None: - return await func(*args, **kwargs) - - # Merge span is parameter is set, but only if no explicit parent are passed - if merge_span and not passed_in_parent: - return await func(*args, **kwargs) - - with change_context(passed_in_parent): - name = name_of_span or get_function_and_class_name(func, *args) - with span_impl_type(name=name) as span: - for key, value in not_none_tracing_attributes.items(): - span.add_attribute(key, value) + name_of_span = kwargs.pop("name_of_span", None) + tracing_attributes = kwargs.pop("tracing_attributes", {}) + + def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]: + @functools.wraps(func) + async def wrapper_use_tracer(*args, **kwargs): + # type: (*Any, **Any) -> T + merge_span = kwargs.pop("merge_span", False) + passed_in_parent = kwargs.pop("parent_span", None) + + span_impl_type = settings.tracing_implementation() + if span_impl_type is None: + return await func(*args, **kwargs) + + # Merge span is parameter is set, but only if no explicit parent are passed + if merge_span and not passed_in_parent: return await func(*args, **kwargs) - return wrapper_use_tracer + with change_context(passed_in_parent): + name = name_of_span or get_function_and_class_name(func, *args) + with span_impl_type(name=name) as span: + for key, value in tracing_attributes.items(): + span.add_attribute(key, value) + return await func(*args, **kwargs) + + return wrapper_use_tracer + + return decorator if __func is None else decorator(__func)