Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] implement metadata for asynchronous unaryunary callable #10

7 changes: 4 additions & 3 deletions src/python/grpcio/grpc/_cython/_cygrpc/aio/call.pyx.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ cdef class _AioCall:
else:
call._waiter_call.set_result(None)

async def unary_unary(self, method, request):
async def unary_unary(self, method, request, metadata=None):
cdef grpc_call * call
cdef grpc_slice method_slice
cdef grpc_op * ops
Expand Down Expand Up @@ -94,7 +94,7 @@ cdef class _AioCall:

ops = <grpc_op *>gpr_malloc(sizeof(grpc_op) * _OP_ARRAY_LENGTH)

initial_metadata_operation = SendInitialMetadataOperation(_EMPTY_METADATA, GRPC_INITIAL_METADATA_USED_MASK)
initial_metadata_operation = SendInitialMetadataOperation(metadata, GRPC_INITIAL_METADATA_USED_MASK)
initial_metadata_operation.c()
ops[0] = <grpc_op> initial_metadata_operation.c_op

Expand Down Expand Up @@ -146,4 +146,5 @@ cdef class _AioCall:
grpc_call_unref(call)
gpr_free(ops)

return receive_message_operation.message()
return (receive_initial_metadata_operation, receive_message_operation,
receive_status_on_client_operation)
4 changes: 2 additions & 2 deletions src/python/grpcio/grpc/_cython/_cygrpc/aio/channel.pyx.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ cdef class AioChannel:
def close(self):
grpc_channel_destroy(self.channel)

async def unary_unary(self, method, request):
async def unary_unary(self, method, request, metadata=None):
call = _AioCall(self)
return await call.unary_unary(method, request)
return await call.unary_unary(method, request, metadata)
61 changes: 56 additions & 5 deletions src/python/grpcio/grpc/experimental/aio/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,29 @@
from grpc.experimental import aio


class _RPCState:

def __init__(self, initial_metadata, trailing_metadata):
self.initial_metadata = initial_metadata
self.response = None
self.trailing_metadata = trailing_metadata


def _handle_call_result(operations, state, response_deserializer):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where did you get this logic from? Is it from the sync version? Can you link it? :) Gives the impression we are implementing more things than needed (for now)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I got it from here: https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_channel.py#L119

I think we may use this function construct more information in the future.

for operation in operations:
operation_type = operation.type()
if operation_type == cygrpc.OperationType.receive_initial_metadata:
state.initial_metadata = operation.initial_metadata()
elif operation_type == cygrpc.OperationType.receive_message:
serialized_response = operation.message()
if serialized_response is not None:
response = _common.deserialize(serialized_response,
response_deserializer)
state.response = response
elif operation_type == cygrpc.OperationType.receive_status_on_client:
state.trailing_metadata = operation.trailing_metadata()


class UnaryUnaryMultiCallable(aio.UnaryUnaryMultiCallable):

def __init__(self, channel, method, request_serializer,
Expand All @@ -38,8 +61,33 @@ async def __call__(self,
if timeout:
raise NotImplementedError("TODO: timeout not implemented yet")

if metadata:
raise NotImplementedError("TODO: metadata not implemented yet")
if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")

if wait_for_ready:
raise NotImplementedError(
"TODO: wait_for_ready not implemented yet")

if compression:
raise NotImplementedError("TODO: compression not implemented yet")

state = _RPCState(None, None)
ops = await self._channel.unary_unary(
self._method, _common.serialize(request, self._request_serializer),
metadata)
_handle_call_result(ops, state, self._response_deserializer)

return state.response

async def with_state(self,
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this method used for? Why did you add it here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should use the name of sync version (with_call) .
When we call this function instead of __call__ , we can return more detail information (like metadata), because __call__ only return response.

I copied it from sync version. This is with_call in sync version: https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_channel.py#L606

And sync version use _end_unary_response_blocking to check which parameters should be returned:
https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_channel.py#L615
https://github.com/grpc/grpc/blob/master/src/python/grpcio/grpc/_channel.py#L498-L502

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the code that is behind with_call will be in some how refactorized for accommodating this proposal grpc#20001.

It's true that what you have implemented can be reused later for implementing that proposal, so IMO I would keep this method - please rename it - but I would put a disclaimer like.

# TODO(https://github.com/grpc/grpc/issues/20001) This method will be removed and its logic will be used for achieving the new unified version of __call__, future and with_call
`` 

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I see 👍

request,
timeout=None,
metadata=None,
credentials=None,
wait_for_ready=None,
compression=None):
if timeout:
raise NotImplementedError("TODO: timeout not implemented yet")

if credentials:
raise NotImplementedError("TODO: credentials not implemented yet")
Expand All @@ -51,10 +99,13 @@ async def __call__(self,
if compression:
raise NotImplementedError("TODO: compression not implemented yet")

response = await self._channel.unary_unary(
self._method, _common.serialize(request, self._request_serializer))
state = _RPCState(None, None)
ops = await self._channel.unary_unary(
self._method, _common.serialize(request, self._request_serializer),
metadata)
_handle_call_result(ops, state, self._response_deserializer)

return _common.deserialize(response, self._response_deserializer)
return state


class Channel(aio.Channel):
Expand Down
3 changes: 2 additions & 1 deletion src/python/grpcio_tests/tests_aio/tests.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[
"_sanity._sanity_test.AioSanityTest",
"unit.channel_test.TestChannel",
"unit.init_test.TestInsecureChannel"
"unit.init_test.TestInsecureChannel",
"unit.metadata_test.TestMetadata"
]
64 changes: 64 additions & 0 deletions src/python/grpcio_tests/tests_aio/unit/metadata_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright 2019 The gRPC Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import unittest

from grpc.experimental import aio
from tests_aio.unit import test_base
from tests_aio.unit import test_common
from src.proto.grpc.testing import messages_pb2


_INVOCATION_METADATA = (
(
'initial-md-key',
'initial-md-value',
),
(
'trailing-md-key-bin',
b'\x00\x02',
),
)


class TestMetadata(test_base.AioTestBase):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put the test in the same file where we have the other tests, in the end we are testing the client with metadata

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 just as a new test_unary_unary_metadata

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1


def test_unary_unary(self):
async def coro():
channel = aio.insecure_channel(self.server_target)
hi = channel.unary_unary(
'/grpc.testing.TestService/UnaryCall',
request_serializer=messages_pb2.SimpleRequest.SerializeToString,
response_deserializer=messages_pb2.SimpleResponse.FromString)
state = await hi.with_state(messages_pb2.SimpleRequest(),
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why? Why not call hi directly as we do in the other tests?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we want to get more information(like metadata) from __call__, I need to change the return value of it. So I add a new method to return more details.

metadata=_INVOCATION_METADATA)

self.assertEqual(type(state), aio._channel._RPCState)
self.assertEqual(type(state.response), messages_pb2.SimpleResponse)
self.assertTrue(
test_common.metadata_transmitted((_INVOCATION_METADATA[0],),
state.initial_metadata))
self.assertTrue(
test_common.metadata_transmitted((_INVOCATION_METADATA[1],),
state.trailing_metadata))

await channel.close()

self.loop.run_until_complete(coro())


if __name__ == '__main__':
logging.basicConfig()
unittest.main(verbosity=2)
19 changes: 18 additions & 1 deletion src/python/grpcio_tests/tests_aio/unit/sync_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,35 @@
import argparse

from concurrent import futures
from time import sleep

import grpc
from src.proto.grpc.testing import messages_pb2
from src.proto.grpc.testing import test_pb2_grpc


_INITIAL_METADATA_KEY = "initial-md-key"
_TRAILING_METADATA_KEY = "trailing-md-key-bin"


def _maybe_echo_metadata(servicer_context):
"""Copies metadata from request to response if it is present."""
invocation_metadata = dict(servicer_context.invocation_metadata())
if _INITIAL_METADATA_KEY in invocation_metadata:
initial_metadatum = (_INITIAL_METADATA_KEY,
invocation_metadata[_INITIAL_METADATA_KEY])
servicer_context.send_initial_metadata((initial_metadatum,))
if _TRAILING_METADATA_KEY in invocation_metadata:
trailing_metadatum = (_TRAILING_METADATA_KEY,
invocation_metadata[_TRAILING_METADATA_KEY])
servicer_context.set_trailing_metadata((trailing_metadatum,))


# TODO (https://github.com/grpc/grpc/issues/19762)
# Change for an asynchronous server version once it's implemented.
class TestServiceServicer(test_pb2_grpc.TestServiceServicer):

def UnaryCall(self, request, context):
_maybe_echo_metadata(context)
return messages_pb2.SimpleResponse()


Expand Down
76 changes: 76 additions & 0 deletions src/python/grpcio_tests/tests_aio/unit/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2019 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common code used throughout aio tests of gRPC."""

import collections

import six

INVOCATION_INITIAL_METADATA = (
('0', 'abc'),
('1', 'def'),
('2', 'ghi'),
)
SERVICE_INITIAL_METADATA = (
('3', 'jkl'),
('4', 'mno'),
('5', 'pqr'),
)
SERVICE_TERMINAL_METADATA = (
('6', 'stu'),
('7', 'vwx'),
('8', 'yza'),
)
DETAILS = 'test details'


def metadata_transmitted(original_metadata, transmitted_metadata):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this test, what is the purpose of it? It is not importing anything from gRPC

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a function used by TestMetadata , this function help it check if the metadata of response is equal to original metadata.

"""Judges whether or not metadata was acceptably transmitted.

gRPC is allowed to insert key-value pairs into the metadata values given by
applications and to reorder key-value pairs with different keys but it is not
allowed to alter existing key-value pairs or to reorder key-value pairs with
the same key.

Args:
original_metadata: A metadata value used in a test of gRPC. An iterable over
iterables of length 2.
transmitted_metadata: A metadata value corresponding to original_metadata
after having been transmitted via gRPC. An iterable over iterables of
length 2.

Returns:
A boolean indicating whether transmitted_metadata accurately reflects
original_metadata after having been transmitted via gRPC.
"""
original = collections.defaultdict(list)
for key, value in original_metadata:
original[key].append(value)
transmitted = collections.defaultdict(list)
for key, value in transmitted_metadata:
transmitted[key].append(value)

for key, values in six.iteritems(original):
transmitted_values = transmitted[key]
transmitted_iterator = iter(transmitted_values)
try:
for value in values:
while True:
transmitted_value = next(transmitted_iterator)
if value == transmitted_value:
break
except StopIteration:
return False
else:
return True