diff --git a/tests/test_tcp.py b/tests/test_tcp.py index 80584a86..6db7c2b7 100644 --- a/tests/test_tcp.py +++ b/tests/test_tcp.py @@ -1263,6 +1263,95 @@ class _TestSSL(tb.SSLTestCase): PAYLOAD_SIZE = 1024 * 100 TIMEOUT = 60 + def test_start_tls_buffer_transfer(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest() + + HELLO_MSG = b'1' * self.PAYLOAD_SIZE + BUFFERED_MSG = b'buffered data before TLS' + + server_context = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + client_context = self._create_client_ssl_context() + + async def handle_client(reader, writer): + # Send data before TLS upgrade + writer.write(BUFFERED_MSG) + await writer.drain() + await asyncio.sleep(0.2) + + # Read pre-TLS data + data = await reader.readexactly(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + # Upgrade to TLS (server side) + try: + # We need the wait_for because the broken version hangs here + await asyncio.wait_for(writer.start_tls(server_context), + timeout=2 + ) + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + except asyncio.TimeoutError: + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + + # Send/receive over TLS + writer.write(b'OK') + await writer.drain() + + data = await reader.readexactly(len(HELLO_MSG)) + self.assertEqual(len(data), len(HELLO_MSG)) + + writer.close() + await self.wait_closed(writer) + + async def client(addr): + # Use open_connection for StreamReader/StreamWriter + reader, writer = await asyncio.open_connection(*addr) + + # Read buffered data before TLS + buffered = await reader.readexactly(len(BUFFERED_MSG)) + self.assertEqual(buffered, BUFFERED_MSG, + "Client didn't receive buffered data before TLS upgrade") + + # Write before TLS upgrade + writer.write(HELLO_MSG) + await writer.drain() + + # Upgrade to TLS + try: + # We need the wait_for because the broken version hangs here + await asyncio.wait_for(writer.start_tls(client_context), + timeout=2 + ) + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + except asyncio.TimeoutError: + self.assertIsNotNone(writer.get_extra_info('sslcontext')) + + # Verify communication over TLS + tls_data = await reader.readexactly(2) + self.assertEqual(tls_data, b'OK', + "Client didn't receive TLS response correctly") + + # Continue over TLS + writer.write(HELLO_MSG) + await writer.drain() + + writer.close() + await self.wait_closed(writer) + + async def run_test(): + srv = await asyncio.start_server( + handle_client, '127.0.0.1', 0, family=socket.AF_INET) + + addr = srv.sockets[0].getsockname() + + await asyncio.wait_for(client(addr), timeout=10) + + srv.close() + await srv.wait_closed() + + self.loop.run_until_complete(run_test()) + def test_create_server_ssl_1(self): CNT = 0 # number of clients that were successful TOTAL_CNT = 25 # total number of clients that test will create diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 2ed1f272..6ed6580a 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -1616,6 +1616,17 @@ cdef class Loop: ssl_shutdown_timeout=ssl_shutdown_timeout, call_connection_made=False) + # Transfer buffered data from the old protocol to the new one. + stream_buff = None + if hasattr(protocol, '_stream_reader'): + stream_reader = protocol._stream_reader + if stream_reader is not None: + stream_buff = getattr(stream_reader, '_buffer', None) + + if stream_buff is not None: + ssl_protocol._incoming.write(stream_buff) + stream_buff.clear() + # Pause early so that "ssl_protocol.data_received()" doesn't # have a chance to get called before "ssl_protocol.connection_made()". transport.pause_reading()