Skip to content

Commit

Permalink
fix(grpc): ensure interceptor args are correctly parsed [backport 2.8] (
Browse files Browse the repository at this point in the history
#9006)

Backport 72b3aa0 from #8707 to 2.8.

Currently the grpcaio integration assumes interceptors are passed to
`grpc.aio.insecure_channel(.....)` and `grpc.aio.secure_channel(....)`
as keyword arguments. When these methods are called with positional
arguments a segfault is occurs.

This PR uses `set_argument_value` and `get_argument_value` helpers to
ensure grpc channel arguments are correctly parsed and set.

Resolves:  #8648

## Checklist

- [x] Change(s) are motivated and described in the PR description
- [x] Testing strategy is described if automated tests are not included
in the PR
- [x] Risks are described (performance impact, potential for breakage,
maintainability)
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] [Library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
are followed or label `changelog/no-changelog` is set
- [x] Documentation is included (in-code, generated user docs, [public
corp docs](https://github.com/DataDog/documentation/))
- [x] Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))
- [x] If this PR changes the public interface, I've notified
`@DataDog/apm-tees`.
- [x] If change touches code that signs or publishes builds or packages,
or handles credentials of any kind, I've requested a review from
`@DataDog/security-design-and-guidance`.

## Reviewer Checklist

- [x] Title is accurate
- [x] All changes are related to the pull request's stated goal
- [x] Description motivates each change
- [x] Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- [x] Testing strategy adequately addresses listed risks
- [x] Change is maintainable (easy to change, telemetry, documentation)
- [x] Release note makes sense to a user of the library
- [x] Author has acknowledged and discussed the performance implications
of this PR as reported in the benchmarks PR comment
- [x] Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)

Co-authored-by: Munir Abdinur <munir.abdinur@datadoghq.com>
  • Loading branch information
github-actions[bot] and mabdinur committed Apr 22, 2024
1 parent 3ef5436 commit 274535c
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 7 deletions.
14 changes: 12 additions & 2 deletions ddtrace/contrib/grpc/aio_client_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def create_aio_client_interceptors(pin, host, port):
)


def _handle_add_callback(call, callback):
try:
call.add_done_callback(callback)
except NotImplementedError:
# add_done_callback is not implemented in UnaryUnaryCallResponse
# https://github.com/grpc/grpc/blob/c54c69dcdd483eba78ed8dbc98c60a8c2d069758/src/python/grpcio/grpc/aio/_interceptor.py#L1058
# If callback is not called, we need to finish the span here
callback(call)


def _done_callback(span, code, details):
# type: (Span, grpc.StatusCode, str) -> Callable[[aio.Call], None]
def func(call):
Expand Down Expand Up @@ -156,7 +166,7 @@ async def _wrap_stream_response(
details = await call.details()
# NOTE: The callback is registered after the iteration is done,
# otherwise `call.code()` and `call.details()` block indefinitely.
call.add_done_callback(_done_callback(span, code, details))
_handle_add_callback(call, _done_callback(span, code, details))
except aio.AioRpcError as rpc_error:
# NOTE: We can also handle the error in done callbacks,
# but reuse this error handling function used in unary response RPCs.
Expand All @@ -182,7 +192,7 @@ async def _wrap_unary_response(
# NOTE: As both `code` and `details` are available after the RPC is done (= we get `call` object),
# and we can't call awaitable functions inside the non-async callback,
# there is no other way but to register the callback here.
call.add_done_callback(_done_callback(span, code, details))
_handle_add_callback(call, _done_callback(span, code, details))
return call
except aio.AioRpcError as rpc_error:
# NOTE: `AioRpcError` is raised in `await continuation(...)`
Expand Down
16 changes: 12 additions & 4 deletions ddtrace/contrib/grpc/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from ddtrace.internal.schema import schematize_service_name
from ddtrace.vendor.wrapt import wrap_function_wrapper as _w

from ...internal.utils import get_argument_value
from ...internal.utils import set_argument_value
from ..trace_utils import unwrap as _u
from . import constants
from . import utils
Expand Down Expand Up @@ -215,12 +217,18 @@ def _aio_client_channel_interceptor(wrapped, instance, args, kwargs):

(host, port) = utils._parse_target_from_args(args, kwargs)

interceptors = create_aio_client_interceptors(pin, host, port)
dd_interceptors = create_aio_client_interceptors(pin, host, port)
interceptor_index = 3
if wrapped.__name__ == "secure_channel":
interceptor_index = 4
interceptors = get_argument_value(args, kwargs, interceptor_index, "interceptors", True)
# DEV: Inject our tracing interceptor first in the list of interceptors
if "interceptors" in kwargs:
kwargs["interceptors"] = interceptors + tuple(kwargs["interceptors"])
if interceptors:
args, kwargs = set_argument_value(
args, kwargs, interceptor_index, "interceptors", dd_interceptors + tuple(interceptors)
)
else:
kwargs["interceptors"] = interceptors
args, kwargs = set_argument_value(args, kwargs, interceptor_index, "interceptors", dd_interceptors, True)

return wrapped(*args, **kwargs)

Expand Down
3 changes: 2 additions & 1 deletion ddtrace/internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def set_argument_value(
pos, # type: int
kw, # type: str
value, # type: Any
override_unset=False, # type: bool
):
# type: (...) -> Tuple[Tuple[Any, ...], Dict[str, Any]]
"""
Expand All @@ -64,7 +65,7 @@ def set_argument_value(
"""
if len(args) > pos:
args = args[:pos] + (value,) + args[pos + 1 :]
elif kw in kwargs:
elif kw in kwargs or override_unset:
kwargs[kw] = value
else:
raise ArgumentError("%s (at position %d) is invalid" % (kw, pos))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
grpc: Resolves segfaults raised when grpc.aio interceptors are registered
26 changes: 26 additions & 0 deletions tests/contrib/grpc_aio/test_grpc_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ def SayHelloRepeatedly(self, request_iterator, context):
yield HelloReply(message="Good bye")


class DummyClientInterceptor(aio.UnaryUnaryClientInterceptor):
async def intercept_unary_unary(self, continuation, client_call_details, request):
undone_call = await continuation(client_call_details, request)
return await undone_call

def add_done_callback(self, unused_callback):
pass


@pytest.fixture(autouse=True)
def patch_grpc_aio():
patch()
Expand Down Expand Up @@ -266,6 +275,23 @@ async def test_secure_channel(server_info, tracer):
_check_server_span(server_span, "grpc-aio-server", "SayHello", "unary")


@pytest.mark.asyncio
@pytest.mark.parametrize("server_info", [_CoroHelloServicer(), _SyncHelloServicer()], indirect=True)
async def test_secure_channel_with_interceptor_in_args(server_info, tracer):
credentials = grpc.ChannelCredentials(None)
interceptors = [DummyClientInterceptor()]
async with aio.secure_channel(server_info.target, credentials, None, None, interceptors) as channel:
stub = HelloStub(channel)
await stub.SayHello(HelloRequest(name="test"))

spans = _get_spans(tracer)
assert len(spans) == 2
client_span, server_span = spans

_check_client_span(client_span, "grpc-aio-client", "SayHello", "unary")
_check_server_span(server_span, "grpc-aio-server", "SayHello", "unary")


@pytest.mark.asyncio
@pytest.mark.parametrize("server_info", [_CoroHelloServicer(), _SyncHelloServicer()], indirect=True)
async def test_invalid_target(server_info, tracer):
Expand Down

0 comments on commit 274535c

Please sign in to comment.