diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index e64c1534..d13f1510 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -87,7 +87,7 @@ async def send_message( metadata=proto_utils.ToProto.metadata(request.metadata), ) ) - if response.task: + if response.HasField('task'): return proto_utils.FromProto.task(response.task) return proto_utils.FromProto.message(response.msg) diff --git a/tests/client/test_grpc_client.py b/tests/client/test_grpc_client.py index c2dbc2b8..259ac75e 100644 --- a/tests/client/test_grpc_client.py +++ b/tests/client/test_grpc_client.py @@ -18,7 +18,7 @@ TaskStatus, TextPart, ) -from a2a.utils import proto_utils +from a2a.utils import get_text_parts, proto_utils # Fixtures @@ -112,6 +112,28 @@ async def test_send_message_task_response( assert response.id == sample_task.id +@pytest.mark.asyncio +async def test_send_message_message_response( + grpc_transport: GrpcTransport, + mock_grpc_stub: AsyncMock, + sample_message_send_params: MessageSendParams, + sample_message: Message, +): + """Test send_message that returns a Message.""" + mock_grpc_stub.SendMessage.return_value = a2a_pb2.SendMessageResponse( + msg=proto_utils.ToProto.message(sample_message) + ) + + response = await grpc_transport.send_message(sample_message_send_params) + + mock_grpc_stub.SendMessage.assert_awaited_once() + assert isinstance(response, Message) + assert response.message_id == sample_message.message_id + assert get_text_parts(response.parts) == get_text_parts( + sample_message.parts + ) + + @pytest.mark.asyncio async def test_get_task( grpc_transport: GrpcTransport, mock_grpc_stub: AsyncMock, sample_task: Task