diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi index 9530a47f389a1..8a38e4fa47a98 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi @@ -59,7 +59,7 @@ cdef class _AioCall: else: call._waiter_call.set_result(None) - async def unary_unary(self, method, request): + async def unary_unary(self, method, request, metadata=None): cdef grpc_call * call cdef grpc_slice method_slice cdef grpc_op * ops @@ -94,7 +94,7 @@ cdef class _AioCall: ops = gpr_malloc(sizeof(grpc_op) * _OP_ARRAY_LENGTH) - initial_metadata_operation = SendInitialMetadataOperation(_EMPTY_METADATA, GRPC_INITIAL_METADATA_USED_MASK) + initial_metadata_operation = SendInitialMetadataOperation(metadata, GRPC_INITIAL_METADATA_USED_MASK) initial_metadata_operation.c() ops[0] = initial_metadata_operation.c_op @@ -146,4 +146,5 @@ cdef class _AioCall: grpc_call_unref(call) gpr_free(ops) - return receive_message_operation.message() + return (receive_initial_metadata_operation, receive_message_operation, + receive_status_on_client_operation) diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi index b52c070553da3..e466516750ed5 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi +++ b/src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi @@ -25,6 +25,6 @@ cdef class AioChannel: def close(self): grpc_channel_destroy(self.channel) - async def unary_unary(self, method, request): + async def unary_unary(self, method, request, metadata=None): call = _AioCall(self) - return await call.unary_unary(method, request) + return await call.unary_unary(method, request, metadata) diff --git a/src/python/grpcio/grpc/experimental/aio/_channel.py b/src/python/grpcio/grpc/experimental/aio/_channel.py index e3c8fcdbf2f4b..0c24acf37be2a 100644 --- a/src/python/grpcio/grpc/experimental/aio/_channel.py +++ b/src/python/grpcio/grpc/experimental/aio/_channel.py @@ -18,6 +18,29 @@ from grpc.experimental import aio +class _RPCState: + + def __init__(self, initial_metadata, trailing_metadata): + self.initial_metadata = initial_metadata + self.response = None + self.trailing_metadata = trailing_metadata + + +def _handle_call_result(operations, state, response_deserializer): + for operation in operations: + operation_type = operation.type() + if operation_type == cygrpc.OperationType.receive_initial_metadata: + state.initial_metadata = operation.initial_metadata() + elif operation_type == cygrpc.OperationType.receive_message: + serialized_response = operation.message() + if serialized_response is not None: + response = _common.deserialize(serialized_response, + response_deserializer) + state.response = response + elif operation_type == cygrpc.OperationType.receive_status_on_client: + state.trailing_metadata = operation.trailing_metadata() + + class UnaryUnaryMultiCallable(aio.UnaryUnaryMultiCallable): def __init__(self, channel, method, request_serializer, @@ -38,8 +61,33 @@ async def __call__(self, if timeout: raise NotImplementedError("TODO: timeout not implemented yet") - if metadata: - raise NotImplementedError("TODO: metadata not implemented yet") + if credentials: + raise NotImplementedError("TODO: credentials not implemented yet") + + if wait_for_ready: + raise NotImplementedError( + "TODO: wait_for_ready not implemented yet") + + if compression: + raise NotImplementedError("TODO: compression not implemented yet") + + state = _RPCState(None, None) + ops = await self._channel.unary_unary( + self._method, _common.serialize(request, self._request_serializer), + metadata) + _handle_call_result(ops, state, self._response_deserializer) + + return state.response + + async def with_state(self, + request, + timeout=None, + metadata=None, + credentials=None, + wait_for_ready=None, + compression=None): + if timeout: + raise NotImplementedError("TODO: timeout not implemented yet") if credentials: raise NotImplementedError("TODO: credentials not implemented yet") @@ -51,10 +99,13 @@ async def __call__(self, if compression: raise NotImplementedError("TODO: compression not implemented yet") - response = await self._channel.unary_unary( - self._method, _common.serialize(request, self._request_serializer)) + state = _RPCState(None, None) + ops = await self._channel.unary_unary( + self._method, _common.serialize(request, self._request_serializer), + metadata) + _handle_call_result(ops, state, self._response_deserializer) - return _common.deserialize(response, self._response_deserializer) + return state class Channel(aio.Channel): diff --git a/src/python/grpcio_tests/tests_aio/tests.json b/src/python/grpcio_tests/tests_aio/tests.json index 49d025a5abeb1..01757b55e9ae4 100644 --- a/src/python/grpcio_tests/tests_aio/tests.json +++ b/src/python/grpcio_tests/tests_aio/tests.json @@ -1,5 +1,6 @@ [ "_sanity._sanity_test.AioSanityTest", "unit.channel_test.TestChannel", - "unit.init_test.TestInsecureChannel" + "unit.init_test.TestInsecureChannel", + "unit.metadata_test.TestMetadata" ] diff --git a/src/python/grpcio_tests/tests_aio/unit/metadata_test.py b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py new file mode 100644 index 0000000000000..1e54303daf9bb --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/metadata_test.py @@ -0,0 +1,64 @@ +# Copyright 2019 The gRPC Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import unittest + +from grpc.experimental import aio +from tests_aio.unit import test_base +from tests_aio.unit import test_common +from src.proto.grpc.testing import messages_pb2 + + +_INVOCATION_METADATA = ( + ( + 'initial-md-key', + 'initial-md-value', + ), + ( + 'trailing-md-key-bin', + b'\x00\x02', + ), +) + + +class TestMetadata(test_base.AioTestBase): + + def test_unary_unary(self): + async def coro(): + channel = aio.insecure_channel(self.server_target) + hi = channel.unary_unary( + '/grpc.testing.TestService/UnaryCall', + request_serializer=messages_pb2.SimpleRequest.SerializeToString, + response_deserializer=messages_pb2.SimpleResponse.FromString) + state = await hi.with_state(messages_pb2.SimpleRequest(), + metadata=_INVOCATION_METADATA) + + self.assertEqual(type(state), aio._channel._RPCState) + self.assertEqual(type(state.response), messages_pb2.SimpleResponse) + self.assertTrue( + test_common.metadata_transmitted((_INVOCATION_METADATA[0],), + state.initial_metadata)) + self.assertTrue( + test_common.metadata_transmitted((_INVOCATION_METADATA[1],), + state.trailing_metadata)) + + await channel.close() + + self.loop.run_until_complete(coro()) + + +if __name__ == '__main__': + logging.basicConfig() + unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests_aio/unit/sync_server.py b/src/python/grpcio_tests/tests_aio/unit/sync_server.py index 105ded8e76ca8..51c3487d96849 100644 --- a/src/python/grpcio_tests/tests_aio/unit/sync_server.py +++ b/src/python/grpcio_tests/tests_aio/unit/sync_server.py @@ -15,18 +15,35 @@ import argparse from concurrent import futures -from time import sleep import grpc from src.proto.grpc.testing import messages_pb2 from src.proto.grpc.testing import test_pb2_grpc +_INITIAL_METADATA_KEY = "initial-md-key" +_TRAILING_METADATA_KEY = "trailing-md-key-bin" + + +def _maybe_echo_metadata(servicer_context): + """Copies metadata from request to response if it is present.""" + invocation_metadata = dict(servicer_context.invocation_metadata()) + if _INITIAL_METADATA_KEY in invocation_metadata: + initial_metadatum = (_INITIAL_METADATA_KEY, + invocation_metadata[_INITIAL_METADATA_KEY]) + servicer_context.send_initial_metadata((initial_metadatum,)) + if _TRAILING_METADATA_KEY in invocation_metadata: + trailing_metadatum = (_TRAILING_METADATA_KEY, + invocation_metadata[_TRAILING_METADATA_KEY]) + servicer_context.set_trailing_metadata((trailing_metadatum,)) + + # TODO (https://github.com/grpc/grpc/issues/19762) # Change for an asynchronous server version once it's implemented. class TestServiceServicer(test_pb2_grpc.TestServiceServicer): def UnaryCall(self, request, context): + _maybe_echo_metadata(context) return messages_pb2.SimpleResponse() diff --git a/src/python/grpcio_tests/tests_aio/unit/test_common.py b/src/python/grpcio_tests/tests_aio/unit/test_common.py new file mode 100644 index 0000000000000..063e86206bd8e --- /dev/null +++ b/src/python/grpcio_tests/tests_aio/unit/test_common.py @@ -0,0 +1,76 @@ +# Copyright 2019 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Common code used throughout aio tests of gRPC.""" + +import collections + +import six + +INVOCATION_INITIAL_METADATA = ( + ('0', 'abc'), + ('1', 'def'), + ('2', 'ghi'), +) +SERVICE_INITIAL_METADATA = ( + ('3', 'jkl'), + ('4', 'mno'), + ('5', 'pqr'), +) +SERVICE_TERMINAL_METADATA = ( + ('6', 'stu'), + ('7', 'vwx'), + ('8', 'yza'), +) +DETAILS = 'test details' + + +def metadata_transmitted(original_metadata, transmitted_metadata): + """Judges whether or not metadata was acceptably transmitted. + + gRPC is allowed to insert key-value pairs into the metadata values given by + applications and to reorder key-value pairs with different keys but it is not + allowed to alter existing key-value pairs or to reorder key-value pairs with + the same key. + + Args: + original_metadata: A metadata value used in a test of gRPC. An iterable over + iterables of length 2. + transmitted_metadata: A metadata value corresponding to original_metadata + after having been transmitted via gRPC. An iterable over iterables of + length 2. + + Returns: + A boolean indicating whether transmitted_metadata accurately reflects + original_metadata after having been transmitted via gRPC. + """ + original = collections.defaultdict(list) + for key, value in original_metadata: + original[key].append(value) + transmitted = collections.defaultdict(list) + for key, value in transmitted_metadata: + transmitted[key].append(value) + + for key, values in six.iteritems(original): + transmitted_values = transmitted[key] + transmitted_iterator = iter(transmitted_values) + try: + for value in values: + while True: + transmitted_value = next(transmitted_iterator) + if value == transmitted_value: + break + except StopIteration: + return False + else: + return True