diff --git a/tests/test_context.py b/tests/test_context.py index 4d3b12ce..2306eedc 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -1,13 +1,150 @@ import asyncio import contextvars import decimal +import itertools import random +import socket +import ssl +import tempfile +import unittest import weakref from uvloop import _testbase as tb +from tests.test_process import _AsyncioTests + + +class _BaseProtocol(asyncio.BaseProtocol): + def __init__(self, cvar, *, loop=None): + self.cvar = cvar + self.transport = None + self.connection_made_fut = asyncio.Future(loop=loop) + self.buffered_ctx = None + self.data_received_fut = asyncio.Future(loop=loop) + self.eof_received_fut = asyncio.Future(loop=loop) + self.pause_writing_fut = asyncio.Future(loop=loop) + self.resume_writing_fut = asyncio.Future(loop=loop) + self.pipe_ctx = {0, 1, 2} + self.pipe_connection_lost_fut = asyncio.Future(loop=loop) + self.process_exited_fut = asyncio.Future(loop=loop) + self.error_received_fut = asyncio.Future(loop=loop) + self.connection_lost_ctx = None + self.done = asyncio.Future(loop=loop) + + def connection_made(self, transport): + self.transport = transport + self.connection_made_fut.set_result(self.cvar.get()) + + def connection_lost(self, exc): + self.connection_lost_ctx = self.cvar.get() + if exc is None: + self.done.set_result(None) + else: + self.done.set_exception(exc) + + def eof_received(self): + self.eof_received_fut.set_result(self.cvar.get()) + + def pause_writing(self): + self.pause_writing_fut.set_result(self.cvar.get()) + + def resume_writing(self): + self.resume_writing_fut.set_result(self.cvar.get()) + + +class _Protocol(_BaseProtocol, asyncio.Protocol): + def data_received(self, data): + self.data_received_fut.set_result(self.cvar.get()) + + +class _BufferedProtocol(_BaseProtocol, asyncio.BufferedProtocol): + def get_buffer(self, sizehint): + if self.buffered_ctx is None: + self.buffered_ctx = self.cvar.get() + elif self.cvar.get() != self.buffered_ctx: + self.data_received_fut.set_exception(ValueError("{} != {}".format( + self.buffered_ctx, self.cvar.get(), + ))) + return bytearray(65536) + + def buffer_updated(self, nbytes): + if not self.data_received_fut.done(): + if self.cvar.get() == self.buffered_ctx: + self.data_received_fut.set_result(self.cvar.get()) + else: + self.data_received_fut.set_exception( + ValueError("{} != {}".format( + self.buffered_ctx, self.cvar.get(), + )) + ) + + +class _DatagramProtocol(_BaseProtocol, asyncio.DatagramProtocol): + def datagram_received(self, data, addr): + self.data_received_fut.set_result(self.cvar.get()) + + def error_received(self, exc): + self.error_received_fut.set_result(self.cvar.get()) + + +class _SubprocessProtocol(_BaseProtocol, asyncio.SubprocessProtocol): + def pipe_data_received(self, fd, data): + self.data_received_fut.set_result(self.cvar.get()) + + def pipe_connection_lost(self, fd, exc): + self.pipe_ctx.remove(fd) + val = self.cvar.get() + self.pipe_ctx.add(val) + if not any(isinstance(x, int) for x in self.pipe_ctx): + if len(self.pipe_ctx) == 1: + self.pipe_connection_lost_fut.set_result(val) + else: + self.pipe_connection_lost_fut.set_exception( + AssertionError(str(list(self.pipe_ctx)))) + + def process_exited(self): + self.process_exited_fut.set_result(self.cvar.get()) + + +class _SSLSocketOverSSL: + # because wrap_socket() doesn't work correctly on + # SSLSocket, we have to do the 2nd level SSL manually + + def __init__(self, ssl_sock, ctx, **kwargs): + self.sock = ssl_sock + self.incoming = ssl.MemoryBIO() + self.outgoing = ssl.MemoryBIO() + self.sslobj = ctx.wrap_bio( + self.incoming, self.outgoing, **kwargs) + self.do(self.sslobj.do_handshake) + + def do(self, func, *args): + while True: + try: + rv = func(*args) + break + except ssl.SSLWantReadError: + if self.outgoing.pending: + self.sock.send(self.outgoing.read()) + self.incoming.write(self.sock.recv(65536)) + if self.outgoing.pending: + self.sock.send(self.outgoing.read()) + return rv + + def send(self, data): + self.do(self.sslobj.write, data) + + def unwrap(self): + self.do(self.sslobj.unwrap) + + def close(self): + self.sock.unwrap() + self.sock.close() -class _ContextBaseTests: +class _ContextBaseTests(tb.SSLTestCase): + + ONLYCERT = tb._cert_fullname(__file__, 'ssl_cert.pem') + ONLYKEY = tb._cert_fullname(__file__, 'ssl_key.pem') def test_task_decimal_context(self): async def fractions(t, precision, x, y): @@ -126,6 +263,494 @@ async def main(): del tracked self.assertIsNone(ref()) + def _run_test(self, method, **switches): + switches.setdefault('use_tcp', 'both') + use_ssl = switches.setdefault('use_ssl', 'no') in {'yes', 'both'} + names = ['factory'] + options = [(_Protocol, _BufferedProtocol)] + for k, v in switches.items(): + if v == 'yes': + options.append((True,)) + elif v == 'no': + options.append((False,)) + elif v == 'both': + options.append((True, False)) + else: + raise ValueError(f"Illegal {k}={v}, can only be yes/no/both") + names.append(k) + + for combo in itertools.product(*options): + values = dict(zip(names, combo)) + with self.subTest(**values): + cvar = contextvars.ContextVar('cvar', default='outer') + values['proto'] = values.pop('factory')(cvar, loop=self.loop) + + async def test(): + self.assertEqual(cvar.get(), 'outer') + cvar.set('inner') + tmp_dir = tempfile.TemporaryDirectory() + if use_ssl: + values['sslctx'] = self._create_server_ssl_context( + self.ONLYCERT, self.ONLYKEY) + values['client_sslctx'] = \ + self._create_client_ssl_context() + else: + values['sslctx'] = values['client_sslctx'] = None + + if values['use_tcp']: + values['addr'] = ('127.0.0.1', tb.find_free_port()) + values['family'] = socket.AF_INET + else: + values['addr'] = tmp_dir.name + '/test.sock' + values['family'] = socket.AF_UNIX + + try: + await method(cvar=cvar, **values) + finally: + tmp_dir.cleanup() + + self.loop.run_until_complete(test()) + + def _run_server_test(self, method, async_sock=False, **switches): + async def test(sslctx, client_sslctx, addr, family, **values): + if values['use_tcp']: + srv = await self.loop.create_server( + lambda: values['proto'], *addr, ssl=sslctx) + else: + srv = await self.loop.create_unix_server( + lambda: values['proto'], addr, ssl=sslctx) + s = socket.socket(family) + + if async_sock: + s.setblocking(False) + await self.loop.sock_connect(s, addr) + else: + await self.loop.run_in_executor( + None, s.connect, addr) + if values['use_ssl']: + values['ssl_sock'] = await self.loop.run_in_executor( + None, client_sslctx.wrap_socket, s) + + try: + await method(s=s, **values) + finally: + if values['use_ssl']: + values['ssl_sock'].close() + s.close() + srv.close() + await srv.wait_closed() + return self._run_test(test, **switches) + + def test_create_server_protocol_factory_context(self): + async def test(cvar, proto, use_tcp, family, addr, **_): + factory_called_future = self.loop.create_future() + + def factory(): + try: + self.assertEqual(cvar.get(), 'inner') + except Exception as e: + factory_called_future.set_exception(e) + else: + factory_called_future.set_result(None) + + return proto + + if use_tcp: + srv = await self.loop.create_server(factory, *addr) + else: + srv = await self.loop.create_unix_server(factory, addr) + s = socket.socket(family) + with s: + s.setblocking(False) + await self.loop.sock_connect(s, addr) + + try: + await factory_called_future + finally: + srv.close() + await proto.done + await srv.wait_closed() + + self._run_test(test) + + def test_create_server_connection_protocol(self): + async def test(proto, s, **_): + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.sock_sendall(s, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + s.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + s.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + + self._run_server_test(test, async_sock=True) + + def test_create_ssl_server_connection_protocol(self): + async def test(cvar, proto, ssl_sock, **_): + def resume_reading(transport): + cvar.set("resume_reading") + transport.resume_reading() + + try: + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, ssl_sock.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + if self.implementation != 'asyncio': + # this seems to be a bug in asyncio + proto.data_received_fut = self.loop.create_future() + proto.transport.pause_reading() + await self.loop.run_in_executor(None, + ssl_sock.send, b'data') + self.loop.call_soon(resume_reading, proto.transport) + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, ssl_sock.unwrap) + else: + ssl_sock.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, ssl_sock.close) + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + if self.implementation == 'asyncio': + # mute resource warning in asyncio + proto.transport.close() + + self._run_server_test(test, use_ssl='yes') + + def test_create_server_manual_connection_lost(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('this seems to be a bug in asyncio') + + async def test(proto, cvar, **_): + def close(): + cvar.set('closing') + proto.transport.close() + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + self.loop.call_soon(close) + + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + + self._run_server_test(test, async_sock=True) + + def test_create_ssl_server_manual_connection_lost(self): + async def test(proto, cvar, ssl_sock, **_): + def close(): + cvar.set('closing') + proto.transport.close() + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + if self.implementation == 'asyncio': + self.loop.call_soon(close) + else: + # asyncio doesn't have the flushing phase + + # put the incoming data on-hold + proto.transport.pause_reading() + # send data + await self.loop.run_in_executor(None, + ssl_sock.send, b'hello') + # schedule a proactive transport close which will trigger + # the flushing process to retrieve the remaining data + self.loop.call_soon(close) + # turn off the reading lock now (this also schedules a + # resume operation after transport.close, therefore it + # won't affect our test) + proto.transport.resume_reading() + + await asyncio.sleep(0) + await self.loop.run_in_executor(None, ssl_sock.unwrap) + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + self.assertFalse(proto.data_received_fut.done()) + + self._run_server_test(test, use_ssl='yes') + + def test_create_connection_protocol(self): + async def test(cvar, proto, addr, sslctx, client_sslctx, family, + use_sock, use_ssl, use_tcp): + ss = socket.socket(family) + ss.bind(addr) + ss.listen(1) + + def accept(): + sock, _ = ss.accept() + if use_ssl: + sock = sslctx.wrap_socket(sock, server_side=True) + return sock + + async def write_over(): + cvar.set("write_over") + count = 0 + if use_ssl: + proto.transport.set_write_buffer_limits(high=256, low=128) + while not proto.transport.get_write_buffer_size(): + proto.transport.write(b'q' * 16384) + count += 1 + else: + proto.transport.write(b'q' * 16384) + proto.transport.set_write_buffer_limits(high=256, low=128) + count += 1 + return count + + s = self.loop.run_in_executor(None, accept) + + try: + method = ('create_connection' if use_tcp + else 'create_unix_connection') + params = {} + if use_sock: + cs = socket.socket(family) + cs.connect(addr) + params['sock'] = cs + if use_ssl: + params['server_hostname'] = '127.0.0.1' + elif use_tcp: + params['host'] = addr[0] + params['port'] = addr[1] + else: + params['path'] = addr + if use_ssl: + params['server_hostname'] = '127.0.0.1' + if use_ssl: + params['ssl'] = client_sslctx + await getattr(self.loop, method)(lambda: proto, **params) + s = await s + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, s.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + if self.implementation != 'asyncio': + # asyncio bug + count = await self.loop.create_task(write_over()) + inner = await proto.pause_writing_fut + self.assertEqual(inner, "inner") + + for i in range(count): + await self.loop.run_in_executor(None, s.recv, 16384) + inner = await proto.resume_writing_fut + self.assertEqual(inner, "inner") + + if use_ssl and self.implementation != 'asyncio': + await self.loop.run_in_executor(None, s.unwrap) + else: + s.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + s.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + ss.close() + proto.transport.close() + + self._run_test(test, use_sock='both', use_ssl='both') + + def test_start_tls(self): + if self.implementation == 'asyncio': + raise unittest.SkipTest('this seems to be a bug in asyncio') + + async def test(cvar, proto, addr, sslctx, client_sslctx, family, + ssl_over_ssl, use_tcp, **_): + ss = socket.socket(family) + ss.bind(addr) + ss.listen(1) + + def accept(): + sock, _ = ss.accept() + sock = sslctx.wrap_socket(sock, server_side=True) + if ssl_over_ssl: + sock = _SSLSocketOverSSL(sock, sslctx, server_side=True) + return sock + + s = self.loop.run_in_executor(None, accept) + transport = None + + try: + if use_tcp: + await self.loop.create_connection(lambda: proto, *addr) + else: + await self.loop.create_unix_connection(lambda: proto, addr) + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + cvar.set('start_tls') + transport = await self.loop.start_tls( + proto.transport, proto, client_sslctx, + server_hostname='127.0.0.1', + ) + + if ssl_over_ssl: + cvar.set('start_tls_over_tls') + transport = await self.loop.start_tls( + transport, proto, client_sslctx, + server_hostname='127.0.0.1', + ) + + s = await s + + await self.loop.run_in_executor(None, s.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, s.unwrap) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + s.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + ss.close() + if transport: + transport.close() + + self._run_test(test, use_ssl='yes', ssl_over_ssl='both') + + def test_connect_accepted_socket(self): + async def test(proto, addr, family, sslctx, client_sslctx, + use_ssl, **_): + ss = socket.socket(family) + ss.bind(addr) + ss.listen(1) + s = self.loop.run_in_executor(None, ss.accept) + cs = socket.socket(family) + cs.connect(addr) + s, _ = await s + + try: + if use_ssl: + cs = self.loop.run_in_executor( + None, client_sslctx.wrap_socket, cs) + await self.loop.connect_accepted_socket(lambda: proto, s, + ssl=sslctx) + cs = await cs + else: + await self.loop.connect_accepted_socket(lambda: proto, s) + + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + await self.loop.run_in_executor(None, cs.send, b'data') + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + if use_ssl and self.implementation != 'asyncio': + await self.loop.run_in_executor(None, cs.unwrap) + else: + cs.shutdown(socket.SHUT_WR) + inner = await proto.eof_received_fut + self.assertEqual(inner, "inner") + + cs.close() + await proto.done + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + proto.transport.close() + ss.close() + + self._run_test(test, use_ssl='both') + + def test_subprocess_protocol(self): + cvar = contextvars.ContextVar('cvar', default='outer') + proto = _SubprocessProtocol(cvar, loop=self.loop) + + async def test(): + self.assertEqual(cvar.get(), 'outer') + cvar.set('inner') + await self.loop.subprocess_exec(lambda: proto, + *_AsyncioTests.PROGRAM_CAT) + + try: + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + proto.transport.get_pipe_transport(0).write(b'data') + proto.transport.get_pipe_transport(0).write_eof() + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + inner = await proto.pipe_connection_lost_fut + self.assertEqual(inner, "inner") + + inner = await proto.process_exited_fut + if self.implementation != 'asyncio': + # bug in asyncio + self.assertEqual(inner, "inner") + + await proto.done + if self.implementation != 'asyncio': + # bug in asyncio + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + proto.transport.close() + + self.loop.run_until_complete(test()) + + def test_datagram_protocol(self): + cvar = contextvars.ContextVar('cvar', default='outer') + proto = _DatagramProtocol(cvar, loop=self.loop) + server_addr = ('127.0.0.1', 8888) + client_addr = ('127.0.0.1', 0) + + async def run(): + self.assertEqual(cvar.get(), 'outer') + cvar.set('inner') + + def close(): + cvar.set('closing') + proto.transport.close() + + try: + await self.loop.create_datagram_endpoint( + lambda: proto, local_addr=server_addr) + inner = await proto.connection_made_fut + self.assertEqual(inner, "inner") + + s = socket.socket(socket.AF_INET, type=socket.SOCK_DGRAM) + s.bind(client_addr) + s.sendto(b'data', server_addr) + inner = await proto.data_received_fut + self.assertEqual(inner, "inner") + + self.loop.call_soon(close) + await proto.done + if self.implementation != 'asyncio': + # bug in asyncio + self.assertEqual(proto.connection_lost_ctx, "inner") + finally: + proto.transport.close() + s.close() + # let transports close + await asyncio.sleep(0.1) + + self.loop.run_until_complete(run()) + class Test_UV_Context(_ContextBaseTests, tb.UVTestCase): pass diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 51e3cdc8..63bdc33f 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -190,11 +190,10 @@ def test_socket_sync_remove_and_immediately_close(self): self.loop.run_until_complete(asyncio.sleep(0.01)) def test_sock_cancel_add_reader_race(self): - if self.is_asyncio_loop(): - if sys.version_info[:2] == (3, 8): - # asyncio 3.8.x has a regression; fixed in 3.9.0 - # tracked in https://bugs.python.org/issue30064 - raise unittest.SkipTest() + if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8): + # asyncio 3.8.x has a regression; fixed in 3.9.0 + # tracked in https://bugs.python.org/issue30064 + raise unittest.SkipTest() srv_sock_conn = None @@ -247,8 +246,8 @@ async def send_server_data(): self.loop.run_until_complete(server()) def test_sock_send_before_cancel(self): - if self.is_asyncio_loop() and sys.version_info[:3] == (3, 8, 0): - # asyncio 3.8.0 seems to have a regression; + if self.is_asyncio_loop() and sys.version_info[:2] == (3, 8): + # asyncio 3.8.x has a regression; fixed in 3.9.0 # tracked in https://bugs.python.org/issue30064 raise unittest.SkipTest() diff --git a/uvloop/cbhandles.pyx b/uvloop/cbhandles.pyx index 2e248dd1..2fe386f7 100644 --- a/uvloop/cbhandles.pyx +++ b/uvloop/cbhandles.pyx @@ -333,71 +333,72 @@ cdef new_Handle(Loop loop, object callback, object args, object context): return handle -cdef new_MethodHandle(Loop loop, str name, method_t callback, object ctx): +cdef new_MethodHandle(Loop loop, str name, method_t callback, object context, + object bound_to): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 2 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to return handle -cdef new_MethodHandle1(Loop loop, str name, method1_t callback, - object ctx, object arg): +cdef new_MethodHandle1(Loop loop, str name, method1_t callback, object context, + object bound_to, object arg): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 3 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to handle.arg2 = arg return handle -cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object ctx, - object arg1, object arg2): +cdef new_MethodHandle2(Loop loop, str name, method2_t callback, object context, + object bound_to, object arg1, object arg2): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 4 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to handle.arg2 = arg1 handle.arg3 = arg2 return handle -cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object ctx, - object arg1, object arg2, object arg3): +cdef new_MethodHandle3(Loop loop, str name, method3_t callback, object context, + object bound_to, object arg1, object arg2, object arg3): cdef Handle handle handle = Handle.__new__(Handle) handle._set_loop(loop) - handle._set_context(None) + handle._set_context(context) handle.cb_type = 5 handle.meth_name = name handle.callback = callback - handle.arg1 = ctx + handle.arg1 = bound_to handle.arg2 = arg1 handle.arg3 = arg2 handle.arg4 = arg3 diff --git a/uvloop/handles/basetransport.pyx b/uvloop/handles/basetransport.pyx index 639df186..6ddecc68 100644 --- a/uvloop/handles/basetransport.pyx +++ b/uvloop/handles/basetransport.pyx @@ -26,6 +26,7 @@ cdef class UVBaseTransport(UVSocketHandle): new_MethodHandle(self._loop, "UVTransport._call_connection_made", self._call_connection_made, + self.context, self)) cdef inline _schedule_call_connection_lost(self, exc): @@ -33,6 +34,7 @@ cdef class UVBaseTransport(UVSocketHandle): new_MethodHandle1(self._loop, "UVTransport._call_connection_lost", self._call_connection_lost, + self.context, self, exc)) cdef _fatal_error(self, exc, throw, reason=None): @@ -66,7 +68,9 @@ cdef class UVBaseTransport(UVSocketHandle): if not self._protocol_paused: self._protocol_paused = 1 try: - self._protocol.pause_writing() + # _maybe_pause_protocol() is always triggered from user-calls, + # so we must copy the context to avoid entering context twice + self.context.copy().run(self._protocol.pause_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -84,7 +88,10 @@ cdef class UVBaseTransport(UVSocketHandle): if self._protocol_paused and size <= self._low_water: self._protocol_paused = 0 try: - self._protocol.resume_writing() + # We're copying the context to avoid entering context twice, + # even though it's not always necessary to copy - it's easier + # to copy here than passing down a copied context. + self.context.copy().run(self._protocol.resume_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: diff --git a/uvloop/handles/handle.pxd b/uvloop/handles/handle.pxd index 01cbed09..5af1c14c 100644 --- a/uvloop/handles/handle.pxd +++ b/uvloop/handles/handle.pxd @@ -5,6 +5,7 @@ cdef class UVHandle: readonly _source_traceback bint _closed bint _inited + object context # Added to enable current UDPTransport implementation, # which doesn't use libuv handles. diff --git a/uvloop/handles/pipe.pxd b/uvloop/handles/pipe.pxd index 7c60fc62..56fc2658 100644 --- a/uvloop/handles/pipe.pxd +++ b/uvloop/handles/pipe.pxd @@ -14,7 +14,7 @@ cdef class UnixTransport(UVStream): @staticmethod cdef UnixTransport new(Loop loop, object protocol, Server server, - object waiter) + object waiter, object context) cdef connect(self, char* addr) diff --git a/uvloop/handles/pipe.pyx b/uvloop/handles/pipe.pyx index bd8809a0..19dc3bd5 100644 --- a/uvloop/handles/pipe.pyx +++ b/uvloop/handles/pipe.pyx @@ -73,9 +73,11 @@ cdef class UnixServer(UVStreamServer): self._mark_as_open() - cdef UVStream _make_new_transport(self, object protocol, object waiter): + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context): cdef UnixTransport tr - tr = UnixTransport.new(self._loop, protocol, self._server, waiter) + tr = UnixTransport.new(self._loop, protocol, self._server, waiter, + context) return tr @@ -84,11 +86,11 @@ cdef class UnixTransport(UVStream): @staticmethod cdef UnixTransport new(Loop loop, object protocol, Server server, - object waiter): + object waiter, object context): cdef UnixTransport handle handle = UnixTransport.__new__(UnixTransport) - handle._init(loop, protocol, server, waiter) + handle._init(loop, protocol, server, waiter, context) __pipe_init_uv_handle(handle, loop) return handle @@ -112,7 +114,9 @@ cdef class ReadUnixTransport(UVStream): object waiter): cdef ReadUnixTransport handle handle = ReadUnixTransport.__new__(ReadUnixTransport) - handle._init(loop, protocol, server, waiter) + # This is only used in connect_read_pipe() and subprocess_shell/exec() + # directly, we could simply copy the current context. + handle._init(loop, protocol, server, waiter, Context_CopyCurrent()) __pipe_init_uv_handle(handle, loop) return handle @@ -162,7 +166,9 @@ cdef class WriteUnixTransport(UVStream): # close the transport. handle._close_on_read_error() - handle._init(loop, protocol, server, waiter) + # This is only used in connect_write_pipe() and subprocess_shell/exec() + # directly, we could simply copy the current context. + handle._init(loop, protocol, server, waiter, Context_CopyCurrent()) __pipe_init_uv_handle(handle, loop) return handle diff --git a/uvloop/handles/process.pyx b/uvloop/handles/process.pyx index a7e81dfc..14931ef5 100644 --- a/uvloop/handles/process.pyx +++ b/uvloop/handles/process.pyx @@ -10,6 +10,7 @@ cdef class UVProcess(UVHandle): self._fds_to_close = set() self._preexec_fn = None self._restore_signals = True + self.context = Context_CopyCurrent() cdef _close_process_handle(self): # XXX: This is a workaround for a libuv bug: @@ -364,7 +365,8 @@ cdef class UVProcessTransport(UVProcess): UVProcess._on_exit(self, exit_status, term_signal) if self._stdio_ready: - self._loop.call_soon(self._protocol.process_exited) + self._loop.call_soon(self._protocol.process_exited, + context=self.context) else: self._pending_calls.append((_CALL_PROCESS_EXITED, None, None)) @@ -383,14 +385,16 @@ cdef class UVProcessTransport(UVProcess): cdef _pipe_connection_lost(self, int fd, exc): if self._stdio_ready: - self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc) + self._loop.call_soon(self._protocol.pipe_connection_lost, fd, exc, + context=self.context) self._try_finish() else: self._pending_calls.append((_CALL_PIPE_CONNECTION_LOST, fd, exc)) cdef _pipe_data_received(self, int fd, data): if self._stdio_ready: - self._loop.call_soon(self._protocol.pipe_data_received, fd, data) + self._loop.call_soon(self._protocol.pipe_data_received, fd, data, + context=self.context) else: self._pending_calls.append((_CALL_PIPE_DATA_RECEIVED, fd, data)) @@ -517,6 +521,7 @@ cdef class UVProcessTransport(UVProcess): cdef _call_connection_made(self, waiter): try: + # we're always called in the right context, so just call the user's self._protocol.connection_made(self) except (KeyboardInterrupt, SystemExit): raise @@ -556,7 +561,9 @@ cdef class UVProcessTransport(UVProcess): self._finished = 1 if self._stdio_ready: - self._loop.call_soon(self._protocol.connection_lost, None) + # copy self.context for simplicity + self._loop.call_soon(self._protocol.connection_lost, None, + context=self.context) else: self._pending_calls.append((_CALL_CONNECTION_LOST, None, None)) @@ -572,6 +579,7 @@ cdef class UVProcessTransport(UVProcess): new_MethodHandle1(self._loop, "UVProcessTransport._call_connection_made", self._call_connection_made, + None, # means to copy the current context self, waiter)) @staticmethod @@ -598,6 +606,8 @@ cdef class UVProcessTransport(UVProcess): if handle._init_futs: handle._stdio_ready = 0 init_fut = aio_gather(*handle._init_futs) + # add_done_callback will copy the current context and run the + # callback within the context init_fut.add_done_callback( ft_partial(handle.__stdio_inited, waiter)) else: @@ -606,6 +616,7 @@ cdef class UVProcessTransport(UVProcess): new_MethodHandle1(loop, "UVProcessTransport._call_connection_made", handle._call_connection_made, + None, # means to copy the current context handle, waiter)) return handle diff --git a/uvloop/handles/stream.pxd b/uvloop/handles/stream.pxd index 401d6f91..21ac6279 100644 --- a/uvloop/handles/stream.pxd +++ b/uvloop/handles/stream.pxd @@ -19,7 +19,7 @@ cdef class UVStream(UVBaseTransport): # All "inline" methods are final cdef inline _init(self, Loop loop, object protocol, Server server, - object waiter) + object waiter, object context) cdef inline _exec_write(self) diff --git a/uvloop/handles/stream.pyx b/uvloop/handles/stream.pyx index 3dc53b68..fe828bde 100644 --- a/uvloop/handles/stream.pyx +++ b/uvloop/handles/stream.pyx @@ -612,7 +612,7 @@ cdef class UVStream(UVBaseTransport): except AttributeError: keep_open = False else: - keep_open = meth() + keep_open = self.context.run(meth) if keep_open: # We're keeping the connection open so the @@ -631,8 +631,8 @@ cdef class UVStream(UVBaseTransport): self._shutdown() cdef inline _init(self, Loop loop, object protocol, Server server, - object waiter): - + object waiter, object context): + self.context = context self._set_protocol(protocol) self._start_init(loop) @@ -826,7 +826,7 @@ cdef inline void __uv_stream_on_read_impl(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc._protocol_data_received(loop._recv_buffer[:nread]) + sc.context.run(sc._protocol_data_received, loop._recv_buffer[:nread]) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 @@ -911,7 +911,7 @@ cdef void __uv_stream_buffered_alloc(uv.uv_handle_t* stream, sc._read_pybuf_acquired = 0 try: - buf = sc._protocol_get_buffer(suggested_size) + buf = sc.context.run(sc._protocol_get_buffer, suggested_size) PyObject_GetBuffer(buf, pybuf, PyBUF_WRITABLE) got_buf = 1 except BaseException as exc: @@ -976,7 +976,7 @@ cdef void __uv_stream_buffered_on_read(uv.uv_stream_t* stream, if UVLOOP_DEBUG: loop._debug_stream_read_cb_total += 1 - sc._protocol_buffer_updated(nread) + sc.context.run(sc._protocol_buffer_updated, nread) except BaseException as exc: if UVLOOP_DEBUG: loop._debug_stream_read_cb_errors_total += 1 diff --git a/uvloop/handles/streamserver.pxd b/uvloop/handles/streamserver.pxd index b2ab1887..a004efd9 100644 --- a/uvloop/handles/streamserver.pxd +++ b/uvloop/handles/streamserver.pxd @@ -22,4 +22,5 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline listen(self) cdef inline _on_listen(self) - cdef UVStream _make_new_transport(self, object protocol, object waiter) + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context) diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 7b2258dd..921c3565 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -53,6 +53,8 @@ cdef class UVStreamServer(UVSocketHandle): if self.opened != 1: raise RuntimeError('unopened TCPServer') + self.context = Context_CopyCurrent() + err = uv.uv_listen( self._handle, self.backlog, __uv_streamserver_on_listen) @@ -64,10 +66,10 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _on_listen(self): cdef UVStream client - protocol = self.protocol_factory() + protocol = self.context.run(self.protocol_factory) if self.ssl is None: - client = self._make_new_transport(protocol, None) + client = self._make_new_transport(protocol, None, self.context) else: waiter = self._loop._new_future() @@ -80,7 +82,7 @@ cdef class UVStreamServer(UVSocketHandle): ssl_handshake_timeout=self.ssl_handshake_timeout, ssl_shutdown_timeout=self.ssl_shutdown_timeout) - client = self._make_new_transport(ssl_protocol, None) + client = self._make_new_transport(ssl_protocol, None, self.context) waiter.add_done_callback( ft_partial(self.__on_ssl_connected, client)) @@ -109,7 +111,8 @@ cdef class UVStreamServer(UVSocketHandle): cdef inline _mark_as_open(self): self.opened = 1 - cdef UVStream _make_new_transport(self, object protocol, object waiter): + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context): raise NotImplementedError def __on_ssl_connected(self, transport, fut): diff --git a/uvloop/handles/tcp.pxd b/uvloop/handles/tcp.pxd index 69284db3..8d388ef0 100644 --- a/uvloop/handles/tcp.pxd +++ b/uvloop/handles/tcp.pxd @@ -23,4 +23,4 @@ cdef class TCPTransport(UVStream): @staticmethod cdef TCPTransport new(Loop loop, object protocol, Server server, - object waiter) + object waiter, object context) diff --git a/uvloop/handles/tcp.pyx b/uvloop/handles/tcp.pyx index 363e0cd8..8c65f69e 100644 --- a/uvloop/handles/tcp.pyx +++ b/uvloop/handles/tcp.pyx @@ -91,9 +91,11 @@ cdef class TCPServer(UVStreamServer): else: self._mark_as_open() - cdef UVStream _make_new_transport(self, object protocol, object waiter): + cdef UVStream _make_new_transport(self, object protocol, object waiter, + object context): cdef TCPTransport tr - tr = TCPTransport.new(self._loop, protocol, self._server, waiter) + tr = TCPTransport.new(self._loop, protocol, self._server, waiter, + context) return tr @@ -102,11 +104,11 @@ cdef class TCPTransport(UVStream): @staticmethod cdef TCPTransport new(Loop loop, object protocol, Server server, - object waiter): + object waiter, object context): cdef TCPTransport handle handle = TCPTransport.__new__(TCPTransport) - handle._init(loop, protocol, server, waiter) + handle._init(loop, protocol, server, waiter, context) __tcp_init_uv_handle(handle, loop, uv.AF_UNSPEC) handle.__peername_set = 0 handle.__sockname_set = 0 diff --git a/uvloop/handles/udp.pxd b/uvloop/handles/udp.pxd index c4d2f3de..daa9a1be 100644 --- a/uvloop/handles/udp.pxd +++ b/uvloop/handles/udp.pxd @@ -19,4 +19,4 @@ cdef class UDPTransport(UVBaseTransport): cdef _send(self, object data, object addr) cdef _on_receive(self, bytes data, object exc, object addr) - cdef _on_sent(self, object exc) + cdef _on_sent(self, object exc, object context=*) diff --git a/uvloop/handles/udp.pyx b/uvloop/handles/udp.pyx index 4d590a65..82dabbf4 100644 --- a/uvloop/handles/udp.pyx +++ b/uvloop/handles/udp.pyx @@ -56,6 +56,7 @@ cdef class UDPTransport(UVBaseTransport): self._family = uv.AF_UNSPEC self.__receiving = 0 self._address = None + self.context = Context_CopyCurrent() cdef _init(self, Loop loop, unsigned int family): cdef int err @@ -252,18 +253,20 @@ cdef class UDPTransport(UVBaseTransport): exc = convert_error(err) self._fatal_error(exc, True) else: - self._on_sent(None) + self._on_sent(None, self.context.copy()) cdef _on_receive(self, bytes data, object exc, object addr): if exc is None: - self._protocol.datagram_received(data, addr) + self.context.run(self._protocol.datagram_received, data, addr) else: - self._protocol.error_received(exc) + self.context.run(self._protocol.error_received, exc) - cdef _on_sent(self, object exc): + cdef _on_sent(self, object exc, object context=None): if exc is not None: if isinstance(exc, OSError): - self._protocol.error_received(exc) + if context is None: + context = self.context + context.run(self._protocol.error_received, exc) else: self._fatal_error( exc, False, 'Fatal write error on datagram transport') diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 61191fb6..a7b24d60 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -157,7 +157,7 @@ cdef class Loop: self.handler_idle = UVIdle.new( self, new_MethodHandle( - self, "loop._on_idle", self._on_idle, self)) + self, "loop._on_idle", self._on_idle, None, self)) # Needed to call `UVStream._exec_write` for writes scheduled # during `Protocol.data_received`. @@ -165,7 +165,7 @@ cdef class Loop: self, new_MethodHandle( self, "loop._exec_queued_writes", - self._exec_queued_writes, self)) + self._exec_queued_writes, None, self)) self._signals = set() self._ssock = self._csock = None @@ -288,6 +288,7 @@ cdef class Loop: self, "Loop._read_from_self", self._read_from_self, + None, self)) self._listening_signals = True @@ -1009,6 +1010,7 @@ cdef class Loop: self, "Loop._sock_sendall", self._sock_sendall, + None, self, fut, sock, data) @@ -1049,6 +1051,7 @@ cdef class Loop: self, "Loop._sock_connect", self._sock_connect_cb, + None, self, fut, sock, address) @@ -1311,6 +1314,7 @@ cdef class Loop: self, "Loop._stop", self._stop, + None, self, None)) @@ -1531,8 +1535,11 @@ cdef class Loop: f'sslcontext is expected to be an instance of ssl.SSLContext, ' f'got {sslcontext!r}') - if not isinstance(transport, (TCPTransport, UnixTransport, - _SSLProtocolTransport)): + if isinstance(transport, (TCPTransport, UnixTransport)): + context = (transport).context + elif isinstance(transport, _SSLProtocolTransport): + context = (<_SSLProtocolTransport>transport).context + else: raise TypeError( f'transport {transport!r} is not supported by start_tls()') @@ -1549,9 +1556,12 @@ cdef class Loop: transport.pause_reading() transport.set_protocol(ssl_protocol) - conmade_cb = self.call_soon(ssl_protocol.connection_made, transport) + conmade_cb = self.call_soon(ssl_protocol.connection_made, transport, + context=context) + # transport.resume_reading() will use the right context + # (transport.context) to call e.g. data_received() resume_cb = self.call_soon(transport.resume_reading) - app_transport = ssl_protocol._get_app_transport() + app_transport = ssl_protocol._get_app_transport(context) try: await waiter @@ -1825,6 +1835,7 @@ cdef class Loop: app_protocol = protocol = protocol_factory() ssl_waiter = None + context = Context_CopyCurrent() if ssl: if server_hostname is None: if not host: @@ -1917,7 +1928,8 @@ cdef class Loop: tr = None try: waiter = self._new_future() - tr = TCPTransport.new(self, protocol, None, waiter) + tr = TCPTransport.new(self, protocol, None, waiter, + context) if lai is not NULL: lai_iter = lai @@ -1977,7 +1989,7 @@ cdef class Loop: sock.setblocking(False) waiter = self._new_future() - tr = TCPTransport.new(self, protocol, None, waiter) + tr = TCPTransport.new(self, protocol, None, waiter, context) try: # libuv will make socket non-blocking tr._open(sock.fileno()) @@ -1997,7 +2009,7 @@ cdef class Loop: tr._attach_fileobj(sock) if ssl: - app_transport = protocol._get_app_transport() + app_transport = protocol._get_app_transport(context) try: await ssl_waiter except (KeyboardInterrupt, SystemExit): @@ -2153,6 +2165,7 @@ cdef class Loop: app_protocol = protocol = protocol_factory() ssl_waiter = None + context = Context_CopyCurrent() if ssl: if server_hostname is None: raise ValueError('You must set server_hostname ' @@ -2186,7 +2199,7 @@ cdef class Loop: path = PyUnicode_EncodeFSDefault(path) waiter = self._new_future() - tr = UnixTransport.new(self, protocol, None, waiter) + tr = UnixTransport.new(self, protocol, None, waiter, context) tr.connect(path) try: await waiter @@ -2210,7 +2223,7 @@ cdef class Loop: sock.setblocking(False) waiter = self._new_future() - tr = UnixTransport.new(self, protocol, None, waiter) + tr = UnixTransport.new(self, protocol, None, waiter, context) try: tr._open(sock.fileno()) tr._init_protocol() @@ -2224,7 +2237,7 @@ cdef class Loop: tr._attach_fileobj(sock) if ssl: - app_transport = protocol._get_app_transport() + app_transport = protocol._get_app_transport(Context_CopyCurrent()) try: await ssl_waiter except (KeyboardInterrupt, SystemExit): @@ -2397,6 +2410,7 @@ cdef class Loop: self, "Loop._sock_recv", self._sock_recv, + None, self, fut, sock, n) @@ -2423,6 +2437,7 @@ cdef class Loop: self, "Loop._sock_recv_into", self._sock_recv_into, + None, self, fut, sock, buf) @@ -2474,6 +2489,7 @@ cdef class Loop: self, "Loop._sock_sendall", self._sock_sendall, + None, self, fut, sock, data) @@ -2504,6 +2520,7 @@ cdef class Loop: self, "Loop._sock_accept", self._sock_accept, + None, self, fut, sock) @@ -2569,6 +2586,7 @@ cdef class Loop: app_protocol = protocol_factory() waiter = self._new_future() transport_waiter = None + context = Context_CopyCurrent() if ssl is None: protocol = app_protocol @@ -2584,10 +2602,10 @@ cdef class Loop: if sock.family == uv.AF_UNIX: transport = UnixTransport.new( - self, protocol, None, transport_waiter) + self, protocol, None, transport_waiter, context) elif sock.family in (uv.AF_INET, uv.AF_INET6): transport = TCPTransport.new( - self, protocol, None, transport_waiter) + self, protocol, None, transport_waiter, context) if transport is None: raise ValueError( @@ -2598,7 +2616,7 @@ cdef class Loop: transport._attach_fileobj(sock) if ssl: - app_transport = protocol._get_app_transport() + app_transport = protocol._get_app_transport(context) try: await waiter except (KeyboardInterrupt, SystemExit): diff --git a/uvloop/sslproto.pxd b/uvloop/sslproto.pxd index a6daa5c0..3da10f00 100644 --- a/uvloop/sslproto.pxd +++ b/uvloop/sslproto.pxd @@ -27,6 +27,7 @@ cdef class _SSLProtocolTransport: Loop _loop SSLProtocol _ssl_protocol bint _closed + object context cdef class SSLProtocol: @@ -97,17 +98,17 @@ cdef class SSLProtocol: # Shutdown flow - cdef _start_shutdown(self) + cdef _start_shutdown(self, object context=*) cdef _check_shutdown_timeout(self) - cdef _do_read_into_void(self) - cdef _do_flush(self) - cdef _do_shutdown(self) + cdef _do_read_into_void(self, object context) + cdef _do_flush(self, object context=*) + cdef _do_shutdown(self, object context=*) cdef _on_shutdown_complete(self, shutdown_exc) cdef _abort(self, exc) # Outgoing flow - cdef _write_appdata(self, list_of_data) + cdef _write_appdata(self, list_of_data, object context) cdef _do_write(self) cdef _process_outgoing(self) @@ -116,18 +117,18 @@ cdef class SSLProtocol: cdef _do_read(self) cdef _do_read__buffered(self) cdef _do_read__copied(self) - cdef _call_eof_received(self) + cdef _call_eof_received(self, object context=*) # Flow control for writes from APP socket - cdef _control_app_writing(self) + cdef _control_app_writing(self, object context=*) cdef size_t _get_write_buffer_size(self) cdef _set_write_buffer_limits(self, high=*, low=*) # Flow control for reads to APP socket cdef _pause_reading(self) - cdef _resume_reading(self) + cdef _resume_reading(self, object context) # Flow control for reads from SSL socket diff --git a/uvloop/sslproto.pyx b/uvloop/sslproto.pyx index ac87d499..d1f976e3 100644 --- a/uvloop/sslproto.pyx +++ b/uvloop/sslproto.pyx @@ -17,11 +17,14 @@ cdef class _SSLProtocolTransport: # TODO: # _sendfile_compatible = constants._SendfileMode.FALLBACK - def __cinit__(self, Loop loop, ssl_protocol): + def __cinit__(self, Loop loop, ssl_protocol, context): self._loop = loop # SSLProtocol instance self._ssl_protocol = ssl_protocol self._closed = False + if context is None: + context = Context_CopyCurrent() + self.context = context def get_extra_info(self, name, default=None): """Get optional transport information.""" @@ -45,7 +48,7 @@ cdef class _SSLProtocolTransport: with None as its argument. """ self._closed = True - self._ssl_protocol._start_shutdown() + self._ssl_protocol._start_shutdown(self.context.copy()) def __dealloc__(self): if not self._closed: @@ -71,7 +74,7 @@ cdef class _SSLProtocolTransport: Data received will once again be passed to the protocol's data_received() method. """ - self._ssl_protocol._resume_reading() + self._ssl_protocol._resume_reading(self.context.copy()) def set_write_buffer_limits(self, high=None, low=None): """Set the high- and low-water limits for write flow control. @@ -93,7 +96,7 @@ cdef class _SSLProtocolTransport: concurrently. """ self._ssl_protocol._set_write_buffer_limits(high, low) - self._ssl_protocol._control_app_writing() + self._ssl_protocol._control_app_writing(self.context.copy()) def get_write_buffer_limits(self): return (self._ssl_protocol._outgoing_low_water, @@ -149,7 +152,7 @@ cdef class _SSLProtocolTransport: f"got {type(data).__name__}") if not data: return - self._ssl_protocol._write_appdata((data,)) + self._ssl_protocol._write_appdata((data,), self.context.copy()) def writelines(self, list_of_data): """Write a list (or any iterable) of data bytes to the transport. @@ -157,7 +160,7 @@ cdef class _SSLProtocolTransport: The default implementation concatenates the arguments and calls write() on the result. """ - self._ssl_protocol._write_appdata(list_of_data) + self._ssl_protocol._write_appdata(list_of_data, self.context.copy()) def write_eof(self): """Close the write end after flushing buffered data. @@ -304,11 +307,12 @@ cdef class SSLProtocol: self._waiter.set_result(None) self._waiter = None - def _get_app_transport(self): + def _get_app_transport(self, context=None): if self._app_transport is None: if self._app_transport_created: raise RuntimeError('Creating _SSLProtocolTransport twice') - self._app_transport = _SSLProtocolTransport(self._loop, self) + self._app_transport = _SSLProtocolTransport(self._loop, self, + context) self._app_transport_created = True return self._app_transport @@ -540,9 +544,12 @@ cdef class SSLProtocol: # Shutdown flow - cdef _start_shutdown(self): + cdef _start_shutdown(self, object context=None): if self._state in (FLUSHING, SHUTDOWN, UNWRAPPED): return + # we don't need the context for _abort or the timeout, because + # TCP transport._force_close() should be able to call + # connection_lost() in the right context if self._app_transport is not None: self._app_transport._closed = True if self._state == DO_HANDSHAKE: @@ -552,14 +559,14 @@ cdef class SSLProtocol: self._shutdown_timeout_handle = \ self._loop.call_later(self._ssl_shutdown_timeout, lambda: self._check_shutdown_timeout()) - self._do_flush() + self._do_flush(context) cdef _check_shutdown_timeout(self): if self._state in (FLUSHING, SHUTDOWN): self._transport._force_close( aio_TimeoutError('SSL shutdown timed out')) - cdef _do_read_into_void(self): + cdef _do_read_into_void(self, object context): """Consume and discard incoming application data. If close_notify is received for the first time, call eof_received. @@ -576,9 +583,9 @@ cdef class SSLProtocol: except ssl_SSLZeroReturnError: close_notify = True if close_notify: - self._call_eof_received() + self._call_eof_received(context) - cdef _do_flush(self): + cdef _do_flush(self, object context=None): """Flush the write backlog, discarding new data received. We don't send close_notify in FLUSHING because we still want to send @@ -587,7 +594,7 @@ cdef class SSLProtocol: in FLUSHING, as we could fully manage the flow control internally. """ try: - self._do_read_into_void() + self._do_read_into_void(context) self._do_write() self._process_outgoing() self._control_ssl_reading() @@ -596,13 +603,13 @@ cdef class SSLProtocol: else: if not self._get_write_buffer_size(): self._set_state(SHUTDOWN) - self._do_shutdown() + self._do_shutdown(context) - cdef _do_shutdown(self): + cdef _do_shutdown(self, object context=None): """Send close_notify and wait for the same from the peer.""" try: # we must skip all application data (if any) before unwrap - self._do_read_into_void() + self._do_read_into_void(context) try: self._sslobj.unwrap() except ssl_SSLAgainErrors as exc: @@ -619,6 +626,8 @@ cdef class SSLProtocol: self._shutdown_timeout_handle.cancel() self._shutdown_timeout_handle = None + # we don't need the context here because TCP transport.close() should + # be able to call connection_made() in the right context if shutdown_exc: self._fatal_error(shutdown_exc, 'Error occurred during shutdown') else: @@ -631,7 +640,7 @@ cdef class SSLProtocol: # Outgoing flow - cdef _write_appdata(self, list_of_data): + cdef _write_appdata(self, list_of_data, object context): if self._state in (FLUSHING, SHUTDOWN, UNWRAPPED): if self._conn_lost >= LOG_THRESHOLD_FOR_CONNLOST_WRITES: aio_logger.warning('SSL connection is closed') @@ -646,7 +655,7 @@ cdef class SSLProtocol: if self._state == WRAPPED: self._do_write() self._process_outgoing() - self._control_app_writing() + self._control_app_writing(context) except Exception as ex: self._fatal_error(ex, 'Fatal error on SSL protocol') @@ -730,6 +739,7 @@ cdef class SSLProtocol: new_MethodHandle(self._loop, "SSLProtocol._do_read", self._do_read, + None, # current context is good self)) except ssl_SSLAgainErrors as exc: pass @@ -774,11 +784,17 @@ cdef class SSLProtocol: self._call_eof_received() self._start_shutdown() - cdef _call_eof_received(self): + cdef _call_eof_received(self, object context=None): if self._app_state == STATE_CON_MADE: self._app_state = STATE_EOF try: - keep_open = self._app_protocol.eof_received() + if context is None: + # If the caller didn't provide a context, we assume the + # caller is already in the right context, which is usually + # inside the upstream callbacks like buffer_updated() + keep_open = self._app_protocol.eof_received() + else: + keep_open = context.run(self._app_protocol.eof_received) except (KeyboardInterrupt, SystemExit): raise except BaseException as ex: @@ -790,12 +806,18 @@ cdef class SSLProtocol: # Flow control for writes from APP socket - cdef _control_app_writing(self): + cdef _control_app_writing(self, object context=None): cdef size_t size = self._get_write_buffer_size() if size >= self._outgoing_high_water and not self._app_writing_paused: self._app_writing_paused = True try: - self._app_protocol.pause_writing() + if context is None: + # If the caller didn't provide a context, we assume the + # caller is already in the right context, which is usually + # inside the upstream callbacks like buffer_updated() + self._app_protocol.pause_writing() + else: + context.run(self._app_protocol.pause_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -808,7 +830,13 @@ cdef class SSLProtocol: elif size <= self._outgoing_low_water and self._app_writing_paused: self._app_writing_paused = False try: - self._app_protocol.resume_writing() + if context is None: + # If the caller didn't provide a context, we assume the + # caller is already in the right context, which is usually + # inside the upstream callbacks like resume_writing() + self._app_protocol.resume_writing() + else: + context.run(self._app_protocol.resume_writing) except (KeyboardInterrupt, SystemExit): raise except BaseException as exc: @@ -833,7 +861,7 @@ cdef class SSLProtocol: cdef _pause_reading(self): self._app_reading_paused = True - cdef _resume_reading(self): + cdef _resume_reading(self, object context): if self._app_reading_paused: self._app_reading_paused = False if self._state == WRAPPED: @@ -841,6 +869,7 @@ cdef class SSLProtocol: new_MethodHandle(self._loop, "SSLProtocol._do_read", self._do_read, + context, self)) # Flow control for reads from SSL socket @@ -890,7 +919,9 @@ cdef class SSLProtocol: self._do_shutdown() cdef _fatal_error(self, exc, message='Fatal error on transport'): - if self._transport: + if self._app_transport: + self._app_transport._force_close(exc) + elif self._transport: self._transport._force_close(exc) if isinstance(exc, OSError):