From ed6561de149d0590b56dbb069e29e863c31b6926 Mon Sep 17 00:00:00 2001 From: Versus Void Date: Wed, 18 Dec 2019 15:17:51 +0300 Subject: [PATCH 1/2] Restore context on listen in UVStreamServer. Fix #305 --- tests/test_context.py | 46 +++++++++++++++++++++++++++++++++ uvloop/handles/streamserver.pxd | 1 + uvloop/handles/streamserver.pyx | 5 +++- 3 files changed, 51 insertions(+), 1 deletion(-) diff --git a/tests/test_context.py b/tests/test_context.py index 4d3b12ce..ce0a456a 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -2,11 +2,23 @@ import contextvars import decimal import random +import socket import weakref from uvloop import _testbase as tb +class _Protocol(asyncio.Protocol): + def __init__(self, *, loop=None): + self.done = asyncio.Future(loop=loop) + + def connection_lost(self, exc): + if exc is None: + self.done.set_result(None) + else: + self.done.set_exception(exc) + + class _ContextBaseTests: def test_task_decimal_context(self): @@ -126,6 +138,40 @@ async def main(): del tracked self.assertIsNone(ref()) + def test_create_server_protocol_factory_context(self): + cvar = contextvars.ContextVar('cvar', default='outer') + factory_called_future = self.loop.create_future() + proto = _Protocol(loop=self.loop) + + def factory(): + try: + self.assertEqual(cvar.get(), 'inner') + except Exception as e: + factory_called_future.set_exception(e) + else: + factory_called_future.set_result(None) + + return proto + + async def test(): + cvar.set('inner') + port = tb.find_free_port() + srv = await self.loop.create_server(factory, '127.0.0.1', port) + + s = socket.socket(socket.AF_INET) + with s: + s.setblocking(False) + await self.loop.sock_connect(s, ('127.0.0.1', port)) + + try: + await factory_called_future + finally: + srv.close() + await proto.done + await srv.wait_closed() + + self.loop.run_until_complete(test()) + class Test_UV_Context(_ContextBaseTests, tb.UVTestCase): pass diff --git a/uvloop/handles/streamserver.pxd b/uvloop/handles/streamserver.pxd index b2ab1887..e2093316 100644 --- a/uvloop/handles/streamserver.pxd +++ b/uvloop/handles/streamserver.pxd @@ -7,6 +7,7 @@ cdef class UVStreamServer(UVSocketHandle): object protocol_factory bint opened Server _server + object listen_context # All "inline" methods are final diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 7b2258dd..c1f4cd4e 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -8,6 +8,7 @@ cdef class UVStreamServer(UVSocketHandle): self.ssl_handshake_timeout = None self.ssl_shutdown_timeout = None self.protocol_factory = None + self.listen_context = None cdef inline _init(self, Loop loop, object protocol_factory, Server server, @@ -53,6 +54,8 @@ cdef class UVStreamServer(UVSocketHandle): if self.opened != 1: raise RuntimeError('unopened TCPServer') + self.listen_context = Context_CopyCurrent() + err = uv.uv_listen( self._handle, self.backlog, __uv_streamserver_on_listen) @@ -64,7 +67,7 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _on_listen(self): cdef UVStream client - protocol = self.protocol_factory() + protocol = self.listen_context.run(self.protocol_factory) if self.ssl is None: client = self._make_new_transport(protocol, None) From 41844b6dd50036215d461abba6bdfcb3920292ec Mon Sep 17 00:00:00 2001 From: Fantix King Date: Thu, 21 Jan 2021 15:13:12 -0500 Subject: [PATCH 2/2] Fix context in protocol callbacks (#348) This is a combined fix to correct contexts from which protocal callbacks are invoked. In short, callbacks like data_received() should always be invoked from consistent contexts which are copied from the context where the underlying UVHandle is created or started. The new test case covers also asyncio, but skipping the failing ones. --- tests/test_context.py | 623 +++++++++++++++++++++++++++++-- tests/test_sockets.py | 13 +- uvloop/cbhandles.pyx | 31 +- uvloop/handles/basetransport.pyx | 11 +- uvloop/handles/handle.pxd | 1 + uvloop/handles/pipe.pxd | 2 +- uvloop/handles/pipe.pyx | 18 +- uvloop/handles/process.pyx | 19 +- uvloop/handles/stream.pxd | 2 +- uvloop/handles/stream.pyx | 12 +- uvloop/handles/streamserver.pxd | 4 +- uvloop/handles/streamserver.pyx | 12 +- uvloop/handles/tcp.pxd | 2 +- uvloop/handles/tcp.pyx | 10 +- uvloop/handles/udp.pxd | 2 +- uvloop/handles/udp.pyx | 13 +- uvloop/loop.pyx | 48 ++- uvloop/sslproto.pxd | 17 +- uvloop/sslproto.pyx | 83 ++-- 19 files changed, 791 insertions(+), 132 deletions(-) diff --git a/tests/test_context.py b/tests/test_context.py index ce0a456a..2306eedc 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,25 +1,150 @@ import asyncio import contextvars import decimal +import itertools import random import socket +import ssl +import tempfile +import unittest import weakref from uvloop import _testbase as tb - - -class _Protocol(asyncio.Protocol): - def __init__(self, *, loop=None): +from tests.test_process import _AsyncioTests + + +class _BaseProtocol(asyncio.BaseProtocol): + def __init__(self, cvar, *, loop=None): + self.cvar = cvar + self.transport = None + self.connection_made_fut = asyncio.Future(loop=loop) + self.buffered_ctx = None + self.data_received_fut = asyncio.Future(loop=loop) + self.eof_received_fut = asyncio.Future(loop=loop) + self.pause_writing_fut = asyncio.Future(loop=loop) + self.resume_writing_fut = asyncio.Future(loop=loop) + self.pipe_ctx = {0, 1, 2} + self.pipe_connection_lost_fut = asyncio.Future(loop=loop) + self.process_exited_fut = asyncio.Future(loop=loop) + self.error_received_fut = asyncio.Future(loop=loop) + self.connection_lost_ctx = None self.done = asyncio.Future(loop=loop) + def connection_made(self, transport): + self.transport = transport + self.connection_made_fut.set_result(self.cvar.get()) + def connection_lost(self, exc): + self.connection_lost_ctx = self.cvar.get() if exc is None: self.done.set_result(None) else: self.done.set_exception(exc) + def eof_received(self): + self.eof_received_fut.set_result(self.cvar.get()) + + def pause_writing(self): + self.pause_writing_fut.set_result(self.cvar.get()) + + def resume_writing(self): + self.resume_writing_fut.set_result(self.cvar.get()) + + +class _Protocol(_BaseProtocol, asyncio.Protocol): + def data_received(self, data): + self.data_received_fut.set_result(self.cvar.get()) + + +class _BufferedProtocol(_BaseProtocol, asyncio.BufferedProtocol): + def get_buffer(self, sizehint): + if self.buffered_ctx is None: + self.buffered_ctx = self.cvar.get() + elif self.cvar.get() != self.buffered_ctx: + self.data_received_fut.set_exception(ValueError("{} != {}".format( + self.buffered_ctx, self.cvar.get(), + ))) + return bytearray(65536) + + def buffer_updated(self, nbytes): + if not self.data_received_fut.done(): + if self.cvar.get() == self.buffered_ctx: + self.data_received_fut.set_result(self.cvar.get()) + else: + self.data_received_fut.set_exception( + ValueError("{} != {}".format( + self.buffered_ctx, self.cvar.get(), + )) + ) + + +class _DatagramProtocol(_BaseProtocol, asyncio.DatagramProtocol): + def datagram_received(self, data, addr): + self.data_received_fut.set_result(self.cvar.get()) + + def error_received(self, exc): + self.error_received_fut.set_result(self.cvar.get()) + + +class _SubprocessProtocol(_BaseProtocol, asyncio.SubprocessProtocol): + def pipe_data_received(self, fd, data): + self.data_received_fut.set_result(self.cvar.get()) + + def pipe_connection_lost(self, fd, exc): + self.pipe_ctx.remove(fd) + val = self.cvar.get() + self.pipe_ctx.add(val) + if not any(isinstance(x, int) for x in self.pipe_ctx): + if len(self.pipe_ctx) == 1: + self.pipe_connection_lost_fut.set_result(val) + else: + self.pipe_connection_lost_fut.set_exception( + AssertionError(str(list(self.pipe_ctx)))) + + def process_exited(self): + self.process_exited_fut.set_result(self.cvar.get()) + + +class _SSLSocketOverSSL: + # because wrap_socket() doesn't work correctly on + # SSLSocket, we have to do the 2nd level SSL manually + + def __init__(self, ssl_sock, ctx, **kwargs): + self.sock = ssl_sock + self.incoming = ssl.MemoryBIO() + self.outgoing = ssl.MemoryBIO() + self.sslobj = ctx.wrap_bio( + self.incoming, self.outgoing, **kwargs) + self.do(self.sslobj.do_handshake) + + def do(self, func, *args): + while True: + try: + rv = func(*args) + break + except ssl.SSLWantReadError: + if self.outgoing.pending: + self.sock.send(self.outgoing.read()) + self.incoming.write(self.sock.recv(65536)) + if self.outgoing.pending: + self.sock.send(self.outgoing.read()) + return rv + + def send(self, data): + self.do(self.sslobj.write, data) + + def unwrap(self): + self.do(self.sslobj.unwrap) + + def close(self): + self.sock.unwrap() + self.sock.close() -class _ContextBaseTests: + +class _ContextBaseTests(tb.SSLTestCase): + + ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem') + ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem') def test_task_decimal_context(self): async def fractions(t, precision, x, y): @@ -138,30 +263,106 @@ async def main(): del tracked self.assertIsNone(ref()) - def test_create_server_protocol_factory_context(self): - cvar = contextvars.ContextVar('cvar', default='outer') - factory_called_future = self.loop.create_future() - proto = _Protocol(loop=self.loop) + def _run_test(self, method, **switches): + switches.setdefault('use_tcp', 'both') + use_ssl = switches.setdefault('use_ssl', 'no') in {'yes', 'both'} + names = ['factory'] + options = [(_Protocol, _BufferedProtocol)] + for k, v in switches.items(): + if v == 'yes': + options.append((True,)) + elif v == 'no': + options.append((False,)) + elif v == 'both': + options.append((True, False)) + else: + raise ValueError(f"Illegal {k}={v}, can only be yes/no/both") + names.append(k) + + for combo in itertools.product(*options): + values = dict(zip(names, combo)) + with self.subTest(**values): + cvar = contextvars.ContextVar('cvar', default='outer') + values['proto'] = values.pop('factory')(cvar, loop=self.loop) + + async def test(): + self.assertEqual(cvar.get(), 'outer') + cvar.set('inner') + tmp_dir = tempfile.TemporaryDirectory() + if use_ssl: + values['sslctx'] = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + values['client_sslctx'] = \ + self._create_client_ssl_context() + else: + values['sslctx'] = values['client_sslctx'] = None + + if values['use_tcp']: + values['addr'] = ('127.0.0.1', tb.find_free_port()) + values['family'] = socket.AF_INET + else: + values['addr'] = tmp_dir.name + '/test.sock' + values['family'] = socket.AF_UNIX + + try: + await method(cvar=cvar, **values) + finally: + tmp_dir.cleanup() + + self.loop.run_until_complete(test()) + + def _run_server_test(self, method, async_sock=False, **switches): + async def test(sslctx, client_sslctx, addr, family, **values): + if values['use_tcp']: + srv = await self.loop.create_server( + lambda: values['proto'], *addr, ssl=sslctx) + else: + srv = await self.loop.create_unix_server( + lambda: values['proto'], addr, ssl=sslctx) + s = socket.socket(family) - def factory(): - try: - self.assertEqual(cvar.get(), 'inner') - except Exception as e: - factory_called_future.set_exception(e) + if async_sock: + s.setblocking(False) + await self.loop.sock_connect(s, addr) else: - factory_called_future.set_result(None) + await self.loop.run_in_executor( + None, s.connect, addr) + if values['use_ssl']: + values['ssl_sock'] = await self.loop.run_in_executor( + None, client_sslctx.wrap_socket, s) - return proto + try: + await method(s=s, **values) + finally: + if values['use_ssl']: + values['ssl_sock'].close() + s.close() + srv.close() + await srv.wait_closed() + return self._run_test(test, **switches) - async def test(): - cvar.set('inner') - port = tb.find_free_port() - srv = await self.loop.create_server(factory, '127.0.0.1', port) + def test_create_server_protocol_factory_context(self): + async def test(cvar, proto, use_tcp, family, addr, **_): + factory_called_future = self.loop.create_future() + + def factory(): + try: + self.assertEqual(cvar.get(), 'inner') + except Exception as e: + factory_called_future.set_exception(e) + else: + factory_called_future.set_result(None) - s = socket.socket(socket.AF_INET) + return proto + + if use_tcp: + srv = await self.loop.create_server(factory, *addr) + else: + srv = await self.loop.create_unix_server(factory, addr) + s = socket.socket(family) with s: s.setblocking(False) - await self.loop.sock_connect(s, ('127.0.0.1', port)) + await self.loop.sock_connect(s, addr) try: await factory_called_future @@ -170,8 +371,386 @@ async def test(): await proto.done await srv.wait_closed() + self._run_test(test) + + def test_create_server_connection_protocol(self): + async def test(proto, s, **_): + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.sock_sendall(s, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + s.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + s.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + + self._run_server_test(test, async_sock=True) + + def test_create_ssl_server_connection_protocol(self): + async def test(cvar, proto, ssl_sock, **_): + def resume_reading(transport): + cvar.set("resume_reading") + transport.resume_reading() + + try: + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, ssl_sock.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + if self.implementation != 'asyncio': + # this seems to be a bug in asyncio + proto.data_received_fut = self.loop.create_future() + proto.transport.pause_reading() + await self.loop.run_in_executor(None, + ssl_sock.send, b'data') + self.loop.call_soon(resume_reading, proto.transport) + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, ssl_sock.unwrap) + else: + ssl_sock.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, ssl_sock.close) + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + if self.implementation == 'asyncio': + # mute resource warning in asyncio + proto.transport.close() + + self._run_server_test(test, use_ssl='yes') + + def test_create_server_manual_connection_lost(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('this seems to be a bug in asyncio') + + async def test(proto, cvar, **_): + def close(): + cvar.set('closing') + proto.transport.close() + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + self.loop.call_soon(close) + + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + + self._run_server_test(test, async_sock=True) + + def test_create_ssl_server_manual_connection_lost(self): + async def test(proto, cvar, ssl_sock, **_): + def close(): + cvar.set('closing') + proto.transport.close() + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + if self.implementation == 'asyncio': + self.loop.call_soon(close) + else: + # asyncio doesn't have the flushing phase + + # put the incoming data on-hold + proto.transport.pause_reading() + # send data + await self.loop.run_in_executor(None, + ssl_sock.send, b'hello') + # schedule a proactive transport close which will trigger + # the flushing process to retrieve the remaining data + self.loop.call_soon(close) + # turn off the reading lock now (this also schedules a + # resume operation after transport.close, therefore it + # won't affect our test) + proto.transport.resume_reading() + + await asyncio.sleep(0) + await self.loop.run_in_executor(None, ssl_sock.unwrap) + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + self.assertFalse(proto.data_received_fut.done()) + + self._run_server_test(test, use_ssl='yes') + + def test_create_connection_protocol(self): + async def test(cvar, proto, addr, sslctx, client_sslctx, family, + use_sock, use_ssl, use_tcp): + ss = socket.socket(family) + ss.bind(addr) + ss.listen(1) + + def accept(): + sock, _ = ss.accept() + if use_ssl: + sock = sslctx.wrap_socket(sock, server_side=True) + return sock + + async def write_over(): + cvar.set("write_over") + count = 0 + if use_ssl: + proto.transport.set_write_buffer_limits(high=256, low=128) + while not proto.transport.get_write_buffer_size(): + proto.transport.write(b'q' * 16384) + count += 1 + else: + proto.transport.write(b'q' * 16384) + proto.transport.set_write_buffer_limits(high=256, low=128) + count += 1 + return count + + s = self.loop.run_in_executor(None, accept) + + try: + method = ('create_connection' if use_tcp + else 'create_unix_connection') + params = {} + if use_sock: + cs = socket.socket(family) + cs.connect(addr) + params['sock'] = cs + if use_ssl: + params['server_hostname'] = '127.0.0.1' + elif use_tcp: + params['host'] = addr[0] + params['port'] = addr[1] + else: + params['path'] = addr + if use_ssl: + params['server_hostname'] = '127.0.0.1' + if use_ssl: + params['ssl'] = client_sslctx + await getattr(self.loop, method)(lambda: proto, **params) + s = await s + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, s.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + if self.implementation != 'asyncio': + # asyncio bug + count = await self.loop.create_task(write_over()) + inner = await proto.pause_writing_fut + self.assertEqual(inner, "inner") + + for i in range(count): + await self.loop.run_in_executor(None, s.recv, 16384) + inner = await proto.resume_writing_fut + self.assertEqual(inner, "inner") + + if use_ssl and self.implementation != 'asyncio': + await self.loop.run_in_executor(None, s.unwrap) + else: + s.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + s.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + ss.close() + proto.transport.close() + + self._run_test(test, use_sock='both', use_ssl='both') + + def test_start_tls(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('this seems to be a bug in asyncio') + + async def test(cvar, proto, addr, sslctx, client_sslctx, family, + ssl_over_ssl, use_tcp, **_): + ss = socket.socket(family) + ss.bind(addr) + ss.listen(1) + + def accept(): + sock, _ = ss.accept() + sock = sslctx.wrap_socket(sock, server_side=True) + if ssl_over_ssl: + sock = _SSLSocketOverSSL(sock, sslctx, server_side=True) + return sock + + s = self.loop.run_in_executor(None, accept) + transport = None + + try: + if use_tcp: + await self.loop.create_connection(lambda: proto, *addr) + else: + await self.loop.create_unix_connection(lambda: proto, addr) + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + cvar.set('start_tls') + transport = await self.loop.start_tls( + proto.transport, proto, client_sslctx, + server_hostname='127.0.0.1', + ) + + if ssl_over_ssl: + cvar.set('start_tls_over_tls') + transport = await self.loop.start_tls( + transport, proto, client_sslctx, + server_hostname='127.0.0.1', + ) + + s = await s + + await self.loop.run_in_executor(None, s.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, s.unwrap) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + s.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + ss.close() + if transport: + transport.close() + + self._run_test(test, use_ssl='yes', ssl_over_ssl='both') + + def test_connect_accepted_socket(self): + async def test(proto, addr, family, sslctx, client_sslctx, + use_ssl, **_): + ss = socket.socket(family) + ss.bind(addr) + ss.listen(1) + s = self.loop.run_in_executor(None, ss.accept) + cs = socket.socket(family) + cs.connect(addr) + s, _ = await s + + try: + if use_ssl: + cs = self.loop.run_in_executor( + None, client_sslctx.wrap_socket, cs) + await self.loop.connect_accepted_socket(lambda: proto, s, + ssl=sslctx) + cs = await cs + else: + await self.loop.connect_accepted_socket(lambda: proto, s) + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, cs.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + if use_ssl and self.implementation != 'asyncio': + await self.loop.run_in_executor(None, cs.unwrap) + else: + cs.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + cs.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + proto.transport.close() + ss.close() + + self._run_test(test, use_ssl='both') + + def test_subprocess_protocol(self): + cvar = contextvars.ContextVar('cvar', default='outer') + proto = _SubprocessProtocol(cvar, loop=self.loop) + + async def test(): + self.assertEqual(cvar.get(), 'outer') + cvar.set('inner') + await self.loop.subprocess_exec(lambda: proto, + *_AsyncioTests.PROGRAM_CAT) + + try: + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + proto.transport.get_pipe_transport(0).write(b'data') + proto.transport.get_pipe_transport(0).write_eof() + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + inner = await proto.pipe_connection_lost_fut + self.assertEqual(inner, "inner") + + inner = await proto.process_exited_fut + if self.implementation != 'asyncio': + # bug in asyncio + self.assertEqual(inner, "inner") + + await proto.done + if self.implementation != 'asyncio': + # bug in asyncio + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + proto.transport.close() + self.loop.run_until_complete(test()) + def test_datagram_protocol(self): + cvar = contextvars.ContextVar('cvar', default='outer') + proto = _DatagramProtocol(cvar, loop=self.loop) + server_addr = ('127.0.0.1', 8888) + client_addr = ('127.0.0.1', 0) + + async def run(): + self.assertEqual(cvar.get(), 'outer') + cvar.set('inner') + + def close(): + cvar.set('closing') + proto.transport.close() + + try: + await self.loop.create_datagram_endpoint( + lambda: proto, local_addr=server_addr) + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM) + s.bind(client_addr) + s.sendto(b'data', server_addr) + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + self.loop.call_soon(close) + await proto.done + if self.implementation != 'asyncio': + # bug in asyncio + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + proto.transport.close() + s.close() + # let transports close + await asyncio.sleep(0.1) + + self.loop.run_until_complete(run()) + class Test_UV_Context(_ContextBaseTests, tb.UVTestCase): pass diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 51e3cdc8..63bdc33f 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -190,11 +190,10 @@ def test_socket_sync_remove_and_immediately_close(self): self.loop.run_until_complete(asyncio.sleep(0.01)) def test_sock_cancel_add_reader_race(self): - if self.is_asyncio_loop(): - if sys.version_info[:2] == (3, 8): - # asyncio 3.8.x has a regression; fixed in 3.9.0 - # tracked in https://bugs.python.org/issue30064 - raise unittest.SkipTest() + if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8): + # asyncio 3.8.x has a regression; fixed in 3.9.0 + # tracked in https://bugs.python.org/issue30064 + raise unittest.SkipTest() srv_sock_conn = None @@ -247,8 +246,8 @@ async def send_server_data(): self.loop.run_until_complete(server()) def test_sock_send_before_cancel(self): - if self.is_asyncio_loop() and sys.version_info[:3] == (3, 8, 0): - # asyncio 3.8.0 seems to have a regression; + if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8): + # asyncio 3.8.x has a regression; fixed in 3.9.0 # tracked in https://bugs.python.org/issue30064 raise unittest.SkipTest() diff --git a/uvloop/cbhandles.pyx b/uvloop/cbhandles.pyx index 2e248dd1..2fe386f7 100644 --- a/uvloop/cbhandles.pyx +++ b/uvloop/cbhandles.pyx @@ -333,71 +333,72 @@ cdef new_Handle(Loop loop, object callback, object args, object context): return handle -cdef new_MethodHandle(Loop loop, str name, method_t callback, object ctx): +cdef new_MethodHandle(Loop loop, str name, method_t callback, object context, + object bound_to): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 2 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to return handle -cdef new_MethodHandle1(Loop loop, str name, method1_t callback, - object ctx, object arg): +cdef new_MethodHandle1(Loop loop, str name, method1_t callback, object context, + object bound_to, object arg): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 3 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to handle.arg2 = arg return handle -cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object ctx, - object arg1, object arg2): +cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object context, + object bound_to, object arg1, object arg2): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 4 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to handle.arg2 = arg1 handle.arg3 = arg2 return handle -cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object ctx, - object arg1, object arg2, object arg3): +cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object context, + object bound_to, object arg1, object arg2, object arg3): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 5 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to handle.arg2 = arg1 handle.arg3 = arg2 handle.arg4 = arg3 diff --git a/uvloop/handles/basetransport.pyx b/uvloop/handles/basetransport.pyx index 639df186..6ddecc68 100644 --- a/uvloop/handles/basetransport.pyx +++ b/uvloop/handles/basetransport.pyx @@ -26,6 +26,7 @@ cdef class UVBaseTransport(UVSocketHandle): new_MethodHandle(self._loop, "UVTransport._call_connection_made", self._call_connection_made, + self.context, self)) cdef inline _schedule_call_connection_lost(self, exc): @@ -33,6 +34,7 @@ cdef class UVBaseTransport(UVSocketHandle): new_MethodHandle1(self._loop, "UVTransport._call_connection_lost", self._call_connection_lost, + self.context, self, exc)) cdef _fatal_error(self, exc, throw, reason=None): @@ -66,7 +68,9 @@ cdef class UVBaseTransport(UVSocketHandle): if not self._protocol_paused: self._protocol_paused = 1 try: - self._protocol.pause_writing() + # _maybe_pause_protocol() is always triggered from user-calls, + # so we must copy the context to avoid entering context twice + self.context.copy().run(self._protocol.pause_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -84,7 +88,10 @@ cdef class UVBaseTransport(UVSocketHandle): if self._protocol_paused and size <= self._low_water: self._protocol_paused = 0 try: - self._protocol.resume_writing() + # We're copying the context to avoid entering context twice, + # even though it's not always necessary to copy - it's easier + # to copy here than passing down a copied context. + self.context.copy().run(self._protocol.resume_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: diff --git a/uvloop/handles/handle.pxd b/uvloop/handles/handle.pxd index 01cbed09..5af1c14c 100644 --- a/uvloop/handles/handle.pxd +++ b/uvloop/handles/handle.pxd @@ -5,6 +5,7 @@ cdef class UVHandle: readonly _source_traceback bint _closed bint _inited + object context # Added to enable current UDPTransport implementation, # which doesn't use libuv handles. diff --git a/uvloop/handles/pipe.pxd b/uvloop/handles/pipe.pxd index 7c60fc62..56fc2658 100644 --- a/uvloop/handles/pipe.pxd +++ b/uvloop/handles/pipe.pxd @@ -14,7 +14,7 @@ cdef class UnixTransport(UVStream): @staticmethod cdef UnixTransport new(Loop loop, object protocol, Server server, - object waiter) + object waiter, object context) cdef connect(self, char* addr) diff --git a/uvloop/handles/pipe.pyx b/uvloop/handles/pipe.pyx index bd8809a0..19dc3bd5 100644 --- a/uvloop/handles/pipe.pyx +++ b/uvloop/handles/pipe.pyx @@ -73,9 +73,11 @@ cdef class UnixServer(UVStreamServer): self._mark_as_open() - cdef UVStream _make_new_transport(self, object protocol, object waiter): + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context): cdef UnixTransport tr - tr = UnixTransport.new(self._loop, protocol, self._server, waiter) + tr = UnixTransport.new(self._loop, protocol, self._server, waiter, + context) return tr @@ -84,11 +86,11 @@ cdef class UnixTransport(UVStream): @staticmethod cdef UnixTransport new(Loop loop, object protocol, Server server, - object waiter): + object waiter, object context): cdef UnixTransport handle handle = UnixTransport.__new__(UnixTransport) - handle._init(loop, protocol, server, waiter) + handle._init(loop, protocol, server, waiter, context) __pipe_init_uv_handle(handle, loop) return handle @@ -112,7 +114,9 @@ cdef class ReadUnixTransport(UVStream): object waiter): cdef ReadUnixTransport handle handle = ReadUnixTransport.__new__(ReadUnixTransport) - handle._init(loop, protocol, server, waiter) + # This is only used in connect_read_pipe() and subprocess_shell/exec() + # directly, we could simply copy the current context. + handle._init(loop, protocol, server, waiter, Context_CopyCurrent()) __pipe_init_uv_handle(handle, loop) return handle @@ -162,7 +166,9 @@ cdef class WriteUnixTransport(UVStream): # close the transport. handle._close_on_read_error() - handle._init(loop, protocol, server, waiter) + # This is only used in connect_write_pipe() and subprocess_shell/exec() + # directly, we could simply copy the current context. + handle._init(loop, protocol, server, waiter, Context_CopyCurrent()) __pipe_init_uv_handle(handle, loop) return handle diff --git a/uvloop/handles/process.pyx b/uvloop/handles/process.pyx index a7e81dfc..14931ef5 100644 --- a/uvloop/handles/process.pyx +++ b/uvloop/handles/process.pyx @@ -10,6 +10,7 @@ cdef class UVProcess(UVHandle): self._fds_to_close = set() self._preexec_fn = None self._restore_signals = True + self.context = Context_CopyCurrent() cdef _close_process_handle(self): # XXX: This is a workaround for a libuv bug: @@ -364,7 +365,8 @@ cdef class UVProcessTransport(UVProcess): UVProcess._on_exit(self, exit_status, term_signal) if self._stdio_ready: - self._loop.call_soon(self._protocol.process_exited) + self._loop.call_soon(self._protocol.process_exited, + context=self.context) else: self._pending_calls.append((_CALL_PROCESS_EXITED, None, None)) @@ -383,14 +385,16 @@ cdef class UVProcessTransport(UVProcess): cdef _pipe_connection_lost(self, int fd, exc): if self._stdio_ready: - self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc) + self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc, + context=self.context) self._try_finish() else: self._pending_calls.append((_CALL_PIPE_CONNECTION_LOST, fd, exc)) cdef _pipe_data_received(self, int fd, data): if self._stdio_ready: - self._loop.call_soon(self._protocol.pipe_data_received, fd, data) + self._loop.call_soon(self._protocol.pipe_data_received, fd, data, + context=self.context) else: self._pending_calls.append((_CALL_PIPE_DATA_RECEIVED, fd, data)) @@ -517,6 +521,7 @@ cdef class UVProcessTransport(UVProcess): cdef _call_connection_made(self, waiter): try: + # we're always called in the right context, so just call the user's self._protocol.connection_made(self) except (KeyboardInterrupt, SystemExit): raise @@ -556,7 +561,9 @@ cdef class UVProcessTransport(UVProcess): self._finished = 1 if self._stdio_ready: - self._loop.call_soon(self._protocol.connection_lost, None) + # copy self.context for simplicity + self._loop.call_soon(self._protocol.connection_lost, None, + context=self.context) else: self._pending_calls.append((_CALL_CONNECTION_LOST, None, None)) @@ -572,6 +579,7 @@ cdef class UVProcessTransport(UVProcess): new_MethodHandle1(self._loop, "UVProcessTransport._call_connection_made", self._call_connection_made, + None, # means to copy the current context self, waiter)) @staticmethod @@ -598,6 +606,8 @@ cdef class UVProcessTransport(UVProcess): if handle._init_futs: handle._stdio_ready = 0 init_fut = aio_gather(*handle._init_futs) + # add_done_callback will copy the current context and run the + # callback within the context init_fut.add_done_callback( ft_partial(handle.__stdio_inited, waiter)) else: @@ -606,6 +616,7 @@ cdef class UVProcessTransport(UVProcess): new_MethodHandle1(loop, "UVProcessTransport._call_connection_made", handle._call_connection_made, + None, # means to copy the current context handle, waiter)) return handle diff --git a/uvloop/handles/stream.pxd b/uvloop/handles/stream.pxd index 401d6f91..21ac6279 100644 --- a/uvloop/handles/stream.pxd +++ b/uvloop/handles/stream.pxd @@ -19,7 +19,7 @@ cdef class UVStream(UVBaseTransport): # All "inline" methods are final cdef inline _init(self, Loop loop, object protocol, Server server, - object waiter) + object waiter, object context) cdef inline _exec_write(self) diff --git a/uvloop/handles/stream.pyx b/uvloop/handles/stream.pyx index 3dc53b68..fe828bde 100644 --- a/uvloop/handles/stream.pyx +++ b/uvloop/handles/stream.pyx @@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport): except AttributeError: keep_open = False else: - keep_open = meth() + keep_open = self.context.run(meth) if keep_open: # We're keeping the connection open so the @@ -631,8 +631,8 @@ cdef class UVStream(UVBaseTransport): self._shutdown() cdef inline _init(self, Loop loop, object protocol, Server server, - object waiter): - + object waiter, object context): + self.context = context self._set_protocol(protocol) self._start_init(loop) @@ -826,7 +826,7 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc._protocol_data_received(loop._recv_buffer[:nread]) + sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread]) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 @@ -911,7 +911,7 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream, sc._read_pybuf_acquired = 0 try: - buf = sc._protocol_get_buffer(suggested_size) + buf = sc.context.run(sc._protocol_get_buffer, suggested_size) PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE) got_buf = 1 except BaseException as exc: @@ -976,7 +976,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc._protocol_buffer_updated(nread) + sc.context.run(sc._protocol_buffer_updated, nread) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 diff --git a/uvloop/handles/streamserver.pxd b/uvloop/handles/streamserver.pxd index e2093316..a004efd9 100644 --- a/uvloop/handles/streamserver.pxd +++ b/uvloop/handles/streamserver.pxd @@ -7,7 +7,6 @@ cdef class UVStreamServer(UVSocketHandle): object protocol_factory bint opened Server _server - object listen_context # All "inline" methods are final @@ -23,4 +22,5 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline listen(self) cdef inline _on_listen(self) - cdef UVStream _make_new_transport(self, object protocol, object waiter) + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context) diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index c1f4cd4e..921c3565 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -8,7 +8,6 @@ cdef class UVStreamServer(UVSocketHandle): self.ssl_handshake_timeout = None self.ssl_shutdown_timeout = None self.protocol_factory = None - self.listen_context = None cdef inline _init(self, Loop loop, object protocol_factory, Server server, @@ -54,7 +53,7 @@ cdef class UVStreamServer(UVSocketHandle): if self.opened != 1: raise RuntimeError('unopened TCPServer') - self.listen_context = Context_CopyCurrent() + self.context = Context_CopyCurrent() err = uv.uv_listen( self._handle, self.backlog, @@ -67,10 +66,10 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _on_listen(self): cdef UVStream client - protocol = self.listen_context.run(self.protocol_factory) + protocol = self.context.run(self.protocol_factory) if self.ssl is None: - client = self._make_new_transport(protocol, None) + client = self._make_new_transport(protocol, None, self.context) else: waiter = self._loop._new_future() @@ -83,7 +82,7 @@ cdef class UVStreamServer(UVSocketHandle): ssl_handshake_timeout=self.ssl_handshake_timeout, ssl_shutdown_timeout=self.ssl_shutdown_timeout) - client = self._make_new_transport(ssl_protocol, None) + client = self._make_new_transport(ssl_protocol, None, self.context) waiter.add_done_callback( ft_partial(self.__on_ssl_connected, client)) @@ -112,7 +111,8 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _mark_as_open(self): self.opened = 1 - cdef UVStream _make_new_transport(self, object protocol, object waiter): + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context): raise NotImplementedError def __on_ssl_connected(self, transport, fut): diff --git a/uvloop/handles/tcp.pxd b/uvloop/handles/tcp.pxd index 69284db3..8d388ef0 100644 --- a/uvloop/handles/tcp.pxd +++ b/uvloop/handles/tcp.pxd @@ -23,4 +23,4 @@ cdef class TCPTransport(UVStream): @staticmethod cdef TCPTransport new(Loop loop, object protocol, Server server, - object waiter) + object waiter, object context) diff --git a/uvloop/handles/tcp.pyx b/uvloop/handles/tcp.pyx index 363e0cd8..8c65f69e 100644 --- a/uvloop/handles/tcp.pyx +++ b/uvloop/handles/tcp.pyx @@ -91,9 +91,11 @@ cdef class TCPServer(UVStreamServer): else: self._mark_as_open() - cdef UVStream _make_new_transport(self, object protocol, object waiter): + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context): cdef TCPTransport tr - tr = TCPTransport.new(self._loop, protocol, self._server, waiter) + tr = TCPTransport.new(self._loop, protocol, self._server, waiter, + context) return tr @@ -102,11 +104,11 @@ cdef class TCPTransport(UVStream): @staticmethod cdef TCPTransport new(Loop loop, object protocol, Server server, - object waiter): + object waiter, object context): cdef TCPTransport handle handle = TCPTransport.__new__(TCPTransport) - handle._init(loop, protocol, server, waiter) + handle._init(loop, protocol, server, waiter, context) __tcp_init_uv_handle(handle, loop, uv.AF_UNSPEC) handle.__peername_set = 0 handle.__sockname_set = 0 diff --git a/uvloop/handles/udp.pxd b/uvloop/handles/udp.pxd index c4d2f3de..daa9a1be 100644 --- a/uvloop/handles/udp.pxd +++ b/uvloop/handles/udp.pxd @@ -19,4 +19,4 @@ cdef class UDPTransport(UVBaseTransport): cdef _send(self, object data, object addr) cdef _on_receive(self, bytes data, object exc, object addr) - cdef _on_sent(self, object exc) + cdef _on_sent(self, object exc, object context=*) diff --git a/uvloop/handles/udp.pyx b/uvloop/handles/udp.pyx index 4d590a65..82dabbf4 100644 --- a/uvloop/handles/udp.pyx +++ b/uvloop/handles/udp.pyx @@ -56,6 +56,7 @@ cdef class UDPTransport(UVBaseTransport): self._family = uv.AF_UNSPEC self.__receiving = 0 self._address = None + self.context = Context_CopyCurrent() cdef _init(self, Loop loop, unsigned int family): cdef int err @@ -252,18 +253,20 @@ cdef class UDPTransport(UVBaseTransport): exc = convert_error(err) self._fatal_error(exc, True) else: - self._on_sent(None) + self._on_sent(None, self.context.copy()) cdef _on_receive(self, bytes data, object exc, object addr): if exc is None: - self._protocol.datagram_received(data, addr) + self.context.run(self._protocol.datagram_received, data, addr) else: - self._protocol.error_received(exc) + self.context.run(self._protocol.error_received, exc) - cdef _on_sent(self, object exc): + cdef _on_sent(self, object exc, object context=None): if exc is not None: if isinstance(exc, OSError): - self._protocol.error_received(exc) + if context is None: + context = self.context + context.run(self._protocol.error_received, exc) else: self._fatal_error( exc, False, 'Fatal write error on datagram transport') diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 61191fb6..a7b24d60 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -157,7 +157,7 @@ cdef class Loop: self.handler_idle = UVIdle.new( self, new_MethodHandle( - self, "loop._on_idle", self._on_idle, self)) + self, "loop._on_idle", self._on_idle, None, self)) # Needed to call `UVStream._exec_write` for writes scheduled # during `Protocol.data_received`. @@ -165,7 +165,7 @@ cdef class Loop: self, new_MethodHandle( self, "loop._exec_queued_writes", - self._exec_queued_writes, self)) + self._exec_queued_writes, None, self)) self._signals = set() self._ssock = self._csock = None @@ -288,6 +288,7 @@ cdef class Loop: self, "Loop._read_from_self", self._read_from_self, + None, self)) self._listening_signals = True @@ -1009,6 +1010,7 @@ cdef class Loop: self, "Loop._sock_sendall", self._sock_sendall, + None, self, fut, sock, data) @@ -1049,6 +1051,7 @@ cdef class Loop: self, "Loop._sock_connect", self._sock_connect_cb, + None, self, fut, sock, address) @@ -1311,6 +1314,7 @@ cdef class Loop: self, "Loop._stop", self._stop, + None, self, None)) @@ -1531,8 +1535,11 @@ cdef class Loop: f'sslcontext is expected to be an instance of ssl.SSLContext, ' f'got {sslcontext!r}') - if not isinstance(transport, (TCPTransport, UnixTransport, - _SSLProtocolTransport)): + if isinstance(transport, (TCPTransport, UnixTransport)): + context = (transport).context + elif isinstance(transport, _SSLProtocolTransport): + context = (<_SSLProtocolTransport>transport).context + else: raise TypeError( f'transport {transport!r} is not supported by start_tls()') @@ -1549,9 +1556,12 @@ cdef class Loop: transport.pause_reading() transport.set_protocol(ssl_protocol) - conmade_cb = self.call_soon(ssl_protocol.connection_made, transport) + conmade_cb = self.call_soon(ssl_protocol.connection_made, transport, + context=context) + # transport.resume_reading() will use the right context + # (transport.context) to call e.g. data_received() resume_cb = self.call_soon(transport.resume_reading) - app_transport = ssl_protocol._get_app_transport() + app_transport = ssl_protocol._get_app_transport(context) try: await waiter @@ -1825,6 +1835,7 @@ cdef class Loop: app_protocol = protocol = protocol_factory() ssl_waiter = None + context = Context_CopyCurrent() if ssl: if server_hostname is None: if not host: @@ -1917,7 +1928,8 @@ cdef class Loop: tr = None try: waiter = self._new_future() - tr = TCPTransport.new(self, protocol, None, waiter) + tr = TCPTransport.new(self, protocol, None, waiter, + context) if lai is not NULL: lai_iter = lai @@ -1977,7 +1989,7 @@ cdef class Loop: sock.setblocking(False) waiter = self._new_future() - tr = TCPTransport.new(self, protocol, None, waiter) + tr = TCPTransport.new(self, protocol, None, waiter, context) try: # libuv will make socket non-blocking tr._open(sock.fileno()) @@ -1997,7 +2009,7 @@ cdef class Loop: tr._attach_fileobj(sock) if ssl: - app_transport = protocol._get_app_transport() + app_transport = protocol._get_app_transport(context) try: await ssl_waiter except (KeyboardInterrupt, SystemExit): @@ -2153,6 +2165,7 @@ cdef class Loop: app_protocol = protocol = protocol_factory() ssl_waiter = None + context = Context_CopyCurrent() if ssl: if server_hostname is None: raise ValueError('You must set server_hostname ' @@ -2186,7 +2199,7 @@ cdef class Loop: path = PyUnicode_EncodeFSDefault(path) waiter = self._new_future() - tr = UnixTransport.new(self, protocol, None, waiter) + tr = UnixTransport.new(self, protocol, None, waiter, context) tr.connect(path) try: await waiter @@ -2210,7 +2223,7 @@ cdef class Loop: sock.setblocking(False) waiter = self._new_future() - tr = UnixTransport.new(self, protocol, None, waiter) + tr = UnixTransport.new(self, protocol, None, waiter, context) try: tr._open(sock.fileno()) tr._init_protocol() @@ -2224,7 +2237,7 @@ cdef class Loop: tr._attach_fileobj(sock) if ssl: - app_transport = protocol._get_app_transport() + app_transport = protocol._get_app_transport(Context_CopyCurrent()) try: await ssl_waiter except (KeyboardInterrupt, SystemExit): @@ -2397,6 +2410,7 @@ cdef class Loop: self, "Loop._sock_recv", self._sock_recv, + None, self, fut, sock, n) @@ -2423,6 +2437,7 @@ cdef class Loop: self, "Loop._sock_recv_into", self._sock_recv_into, + None, self, fut, sock, buf) @@ -2474,6 +2489,7 @@ cdef class Loop: self, "Loop._sock_sendall", self._sock_sendall, + None, self, fut, sock, data) @@ -2504,6 +2520,7 @@ cdef class Loop: self, "Loop._sock_accept", self._sock_accept, + None, self, fut, sock) @@ -2569,6 +2586,7 @@ cdef class Loop: app_protocol = protocol_factory() waiter = self._new_future() transport_waiter = None + context = Context_CopyCurrent() if ssl is None: protocol = app_protocol @@ -2584,10 +2602,10 @@ cdef class Loop: if sock.family == uv.AF_UNIX: transport = UnixTransport.new( - self, protocol, None, transport_waiter) + self, protocol, None, transport_waiter, context) elif sock.family in (uv.AF_INET, uv.AF_INET6): transport = TCPTransport.new( - self, protocol, None, transport_waiter) + self, protocol, None, transport_waiter, context) if transport is None: raise ValueError( @@ -2598,7 +2616,7 @@ cdef class Loop: transport._attach_fileobj(sock) if ssl: - app_transport = protocol._get_app_transport() + app_transport = protocol._get_app_transport(context) try: await waiter except (KeyboardInterrupt, SystemExit): diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd index a6daa5c0..3da10f00 100644 --- a/uvloop/sslproto.pxd +++ b/uvloop/sslproto.pxd @@ -27,6 +27,7 @@ cdef class _SSLProtocolTransport: Loop _loop SSLProtocol _ssl_protocol bint _closed + object context cdef class SSLProtocol: @@ -97,17 +98,17 @@ cdef class SSLProtocol: # Shutdown flow - cdef _start_shutdown(self) + cdef _start_shutdown(self, object context=*) cdef _check_shutdown_timeout(self) - cdef _do_read_into_void(self) - cdef _do_flush(self) - cdef _do_shutdown(self) + cdef _do_read_into_void(self, object context) + cdef _do_flush(self, object context=*) + cdef _do_shutdown(self, object context=*) cdef _on_shutdown_complete(self, shutdown_exc) cdef _abort(self, exc) # Outgoing flow - cdef _write_appdata(self, list_of_data) + cdef _write_appdata(self, list_of_data, object context) cdef _do_write(self) cdef _process_outgoing(self) @@ -116,18 +117,18 @@ cdef class SSLProtocol: cdef _do_read(self) cdef _do_read__buffered(self) cdef _do_read__copied(self) - cdef _call_eof_received(self) + cdef _call_eof_received(self, object context=*) # Flow control for writes from APP socket - cdef _control_app_writing(self) + cdef _control_app_writing(self, object context=*) cdef size_t _get_write_buffer_size(self) cdef _set_write_buffer_limits(self, high=*, low=*) # Flow control for reads to APP socket cdef _pause_reading(self) - cdef _resume_reading(self) + cdef _resume_reading(self, object context) # Flow control for reads from SSL socket diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index ac87d499..d1f976e3 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -17,11 +17,14 @@ cdef class _SSLProtocolTransport: # TODO: # _sendfile_compatible = constants._SendfileMode.FALLBACK - def __cinit__(self, Loop loop, ssl_protocol): + def __cinit__(self, Loop loop, ssl_protocol, context): self._loop = loop # SSLProtocol instance self._ssl_protocol = ssl_protocol self._closed = False + if context is None: + context = Context_CopyCurrent() + self.context = context def get_extra_info(self, name, default=None): """Get optional transport information.""" @@ -45,7 +48,7 @@ cdef class _SSLProtocolTransport: with None as its argument. """ self._closed = True - self._ssl_protocol._start_shutdown() + self._ssl_protocol._start_shutdown(self.context.copy()) def __dealloc__(self): if not self._closed: @@ -71,7 +74,7 @@ cdef class _SSLProtocolTransport: Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._resume_reading() + self._ssl_protocol._resume_reading(self.context.copy()) def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -93,7 +96,7 @@ cdef class _SSLProtocolTransport: concurrently. """ self._ssl_protocol._set_write_buffer_limits(high, low) - self._ssl_protocol._control_app_writing() + self._ssl_protocol._control_app_writing(self.context.copy()) def get_write_buffer_limits(self): return (self._ssl_protocol._outgoing_low_water, @@ -149,7 +152,7 @@ cdef class _SSLProtocolTransport: f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata((data,)) + self._ssl_protocol._write_appdata((data,), self.context.copy()) def writelines(self, list_of_data): """Write a list (or any iterable) of data bytes to the transport. @@ -157,7 +160,7 @@ cdef class _SSLProtocolTransport: The default implementation concatenates the arguments and calls write() on the result. """ - self._ssl_protocol._write_appdata(list_of_data) + self._ssl_protocol._write_appdata(list_of_data, self.context.copy()) def write_eof(self): """Close the write end after flushing buffered data. @@ -304,11 +307,12 @@ cdef class SSLProtocol: self._waiter.set_result(None) self._waiter = None - def _get_app_transport(self): + def _get_app_transport(self, context=None): if self._app_transport is None: if self._app_transport_created: raise RuntimeError('Creating _SSLProtocolTransport twice') - self._app_transport = _SSLProtocolTransport(self._loop, self) + self._app_transport = _SSLProtocolTransport(self._loop, self, + context) self._app_transport_created = True return self._app_transport @@ -540,9 +544,12 @@ cdef class SSLProtocol: # Shutdown flow - cdef _start_shutdown(self): + cdef _start_shutdown(self, object context=None): if self._state in (FLUSHING, SHUTDOWN, UNWRAPPED): return + # we don't need the context for _abort or the timeout, because + # TCP transport._force_close() should be able to call + # connection_lost() in the right context if self._app_transport is not None: self._app_transport._closed = True if self._state == DO_HANDSHAKE: @@ -552,14 +559,14 @@ cdef class SSLProtocol: self._shutdown_timeout_handle = \ self._loop.call_later(self._ssl_shutdown_timeout, lambda: self._check_shutdown_timeout()) - self._do_flush() + self._do_flush(context) cdef _check_shutdown_timeout(self): if self._state in (FLUSHING, SHUTDOWN): self._transport._force_close( aio_TimeoutError('SSL shutdown timed out')) - cdef _do_read_into_void(self): + cdef _do_read_into_void(self, object context): """Consume and discard incoming application data. If close_notify is received for the first time, call eof_received. @@ -576,9 +583,9 @@ cdef class SSLProtocol: except ssl_SSLZeroReturnError: close_notify = True if close_notify: - self._call_eof_received() + self._call_eof_received(context) - cdef _do_flush(self): + cdef _do_flush(self, object context=None): """Flush the write backlog, discarding new data received. We don't send close_notify in FLUSHING because we still want to send @@ -587,7 +594,7 @@ cdef class SSLProtocol: in FLUSHING, as we could fully manage the flow control internally. """ try: - self._do_read_into_void() + self._do_read_into_void(context) self._do_write() self._process_outgoing() self._control_ssl_reading() @@ -596,13 +603,13 @@ cdef class SSLProtocol: else: if not self._get_write_buffer_size(): self._set_state(SHUTDOWN) - self._do_shutdown() + self._do_shutdown(context) - cdef _do_shutdown(self): + cdef _do_shutdown(self, object context=None): """Send close_notify and wait for the same from the peer.""" try: # we must skip all application data (if any) before unwrap - self._do_read_into_void() + self._do_read_into_void(context) try: self._sslobj.unwrap() except ssl_SSLAgainErrors as exc: @@ -619,6 +626,8 @@ cdef class SSLProtocol: self._shutdown_timeout_handle.cancel() self._shutdown_timeout_handle = None + # we don't need the context here because TCP transport.close() should + # be able to call connection_made() in the right context if shutdown_exc: self._fatal_error(shutdown_exc, 'Error occurred during shutdown') else: @@ -631,7 +640,7 @@ cdef class SSLProtocol: # Outgoing flow - cdef _write_appdata(self, list_of_data): + cdef _write_appdata(self, list_of_data, object context): if self._state in (FLUSHING, SHUTDOWN, UNWRAPPED): if self._conn_lost >= LOG_THRESHOLD_FOR_CONNLOST_WRITES: aio_logger.warning('SSL connection is closed') @@ -646,7 +655,7 @@ cdef class SSLProtocol: if self._state == WRAPPED: self._do_write() self._process_outgoing() - self._control_app_writing() + self._control_app_writing(context) except Exception as ex: self._fatal_error(ex, 'Fatal error on SSL protocol') @@ -730,6 +739,7 @@ cdef class SSLProtocol: new_MethodHandle(self._loop, "SSLProtocol._do_read", self._do_read, + None, # current context is good self)) except ssl_SSLAgainErrors as exc: pass @@ -774,11 +784,17 @@ cdef class SSLProtocol: self._call_eof_received() self._start_shutdown() - cdef _call_eof_received(self): + cdef _call_eof_received(self, object context=None): if self._app_state == STATE_CON_MADE: self._app_state = STATE_EOF try: - keep_open = self._app_protocol.eof_received() + if context is None: + # If the caller didn't provide a context, we assume the + # caller is already in the right context, which is usually + # inside the upstream callbacks like buffer_updated() + keep_open = self._app_protocol.eof_received() + else: + keep_open = context.run(self._app_protocol.eof_received) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: @@ -790,12 +806,18 @@ cdef class SSLProtocol: # Flow control for writes from APP socket - cdef _control_app_writing(self): + cdef _control_app_writing(self, object context=None): cdef size_t size = self._get_write_buffer_size() if size >= self._outgoing_high_water and not self._app_writing_paused: self._app_writing_paused = True try: - self._app_protocol.pause_writing() + if context is None: + # If the caller didn't provide a context, we assume the + # caller is already in the right context, which is usually + # inside the upstream callbacks like buffer_updated() + self._app_protocol.pause_writing() + else: + context.run(self._app_protocol.pause_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -808,7 +830,13 @@ cdef class SSLProtocol: elif size <= self._outgoing_low_water and self._app_writing_paused: self._app_writing_paused = False try: - self._app_protocol.resume_writing() + if context is None: + # If the caller didn't provide a context, we assume the + # caller is already in the right context, which is usually + # inside the upstream callbacks like resume_writing() + self._app_protocol.resume_writing() + else: + context.run(self._app_protocol.resume_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -833,7 +861,7 @@ cdef class SSLProtocol: cdef _pause_reading(self): self._app_reading_paused = True - cdef _resume_reading(self): + cdef _resume_reading(self, object context): if self._app_reading_paused: self._app_reading_paused = False if self._state == WRAPPED: @@ -841,6 +869,7 @@ cdef class SSLProtocol: new_MethodHandle(self._loop, "SSLProtocol._do_read", self._do_read, + context, self)) # Flow control for reads from SSL socket @@ -890,7 +919,9 @@ cdef class SSLProtocol: self._do_shutdown() cdef _fatal_error(self, exc, message='Fatal error on transport'): - if self._transport: + if self._app_transport: + self._app_transport._force_close(exc) + elif self._transport: self._transport._force_close(exc) if isinstance(exc, OSError):