From 087f79f3fa7cd869d92dd72b8c4e1f284c98a073 Mon Sep 17 00:00:00 2001 From: Will Miller Date: Tue, 30 Jan 2024 17:19:28 +0000 Subject: [PATCH] TLSUpgradeProto: don't set multiple results for an event In the case of a misbehaving server, the client may receive more than one byte in separate data_received() invocations from the server. While we can't do much sane with this, we should handle it gracefully and not crash with asyncio.InvalidStateError when trying to set another result on the event. Fixes #729 --- asyncpg/_testbase/__init__.py | 39 +++++++++++++++++++++++++ asyncpg/connect_utils.py | 5 ++++ tests/test_connect.py | 55 +++++++++++++++++++++++++++++++++++ 3 files changed, 99 insertions(+) diff --git a/asyncpg/_testbase/__init__.py b/asyncpg/_testbase/__init__.py index 7aca834f..ee01a831 100644 --- a/asyncpg/_testbase/__init__.py +++ b/asyncpg/_testbase/__init__.py @@ -13,6 +13,7 @@ import logging import os import re +import socket import textwrap import time import traceback @@ -525,3 +526,41 @@ def connect_standby(cls, **kwargs): kwargs ) return pg_connection.connect(**conn_spec, loop=cls.loop) + + +class InstrumentedServer: + """ + A socket server for testing. + It will write each item from `data`, and wait for the corresponding event + in `received_events` to notify that it was received before writing the next + item from `data`. + """ + def __init__(self, data, received_events): + assert len(data) == len(received_events) + self._data = data + self._server = None + self._received_events = received_events + + async def _handle_client(self, _reader, writer): + for datum, received_event in zip(self._data, self._received_events): + writer.write(datum) + await writer.drain() + await received_event.wait() + + writer.close() + await writer.wait_closed() + + async def start(self): + """Start the server.""" + self._server = await asyncio.start_server(self._handle_client, 'localhost', 0) + assert len(self._server.sockets) == 1 + sock = self._server.sockets[0] + addr, port = sock.getsockname() + return { + 'host': addr, + 'port': port, + } + + def stop(self): + """Stop the server.""" + self._server.close() diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..5bdf6f97 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -714,6 +714,11 @@ def __init__(self, loop, host, port, ssl_context, ssl_is_advisory): self.ssl_is_advisory = ssl_is_advisory def data_received(self, data): + if self.on_data.done(): + # Only expect to receive one byte here; ignore unsolicited further + # data. + return + if data == b'S': self.on_data.set_result(True) elif (self.ssl_is_advisory and diff --git a/tests/test_connect.py b/tests/test_connect.py index 5333e2c5..b6ebca63 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -7,6 +7,7 @@ import asyncio import contextlib +import copy import gc import ipaddress import os @@ -17,11 +18,13 @@ import stat import tempfile import textwrap +import time import unittest import unittest.mock import urllib.parse import warnings import weakref +from unittest import mock import asyncpg from asyncpg import _testbase as tb @@ -1989,6 +1992,58 @@ async def test_prefer_standby_picks_master_when_standby_is_down(self): await con.close() +class TestMisbehavingServer(tb.TestCase): + """Tests for client connection behaviour given a misbehaving server.""" + + async def test_tls_upgrade_extra_data_received(self): + data = [ + # First, the server writes b"S" to signal it is willing to perform + # SSL + b"S", + # Then, the server writes an unsolicted arbitrary byte afterwards + b"N", + ] + data_received_events = [asyncio.Event() for _ in data] + + # Patch out the loop's create_connection so we can instrument the proto + # we return. + old_create_conn = self.loop.create_connection + + async def _mock_create_conn(*args, **kwargs): + transport, proto = await old_create_conn(*args, **kwargs) + old_data_received = proto.data_received + + num_received = 0 + + def _data_received(*args, **kwargs): + nonlocal num_received + # Call the original data_received method + ret = old_data_received(*args, **kwargs) + # Fire the event to signal we've received this datum now. + data_received_events[num_received].set() + num_received += 1 + return ret + + proto.data_received = _data_received + + # To deterministically provoke the race we're interested in for + # this regression test, wait for all data to be received before + # returning from create_connection(). + await data_received_events[-1].wait() + return transport, proto + + server = tb.InstrumentedServer(data, data_received_events) + conn_spec = await server.start() + + # The call to connect() should raise a ConnectionResetError as the + # server will close the connection after writing all the data. + with (mock.patch.object(self.loop, "create_connection", side_effect=_mock_create_conn), + self.assertRaises(ConnectionResetError)): + await pg_connection.connect(**conn_spec, ssl=True, loop=self.loop) + + server.stop() + + def _get_connected_host(con): peername = con._transport.get_extra_info('peername') if isinstance(peername, tuple):