diff --git a/tests/test_context.py b/tests/test_context.py index 0f340cb8..44eae3f6 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -4,6 +4,7 @@ import sys import unittest import weakref +import socket from uvloop import _testbase as tb @@ -139,6 +140,40 @@ async def main(): del tracked self.assertIsNone(ref()) + @unittest.skipUnless(PY37, 'requires Python 3.7') + def test_create_server_protocol_factory_context(self): + import contextvars + cvar = contextvars.ContextVar('cvar', default='outer') + factory_called_future = self.loop.create_future() + + def factory(): + try: + self.assertEqual(cvar.get(), 'inner') + except Exception as e: + factory_called_future.set_exception(e) + else: + factory_called_future.set_result(None) + + return asyncio.Protocol() + + async def test(): + cvar.set('inner') + port = tb.find_free_port() + srv = await self.loop.create_server(factory, '127.0.0.1', port) + + s = socket.socket(socket.AF_INET) + with s: + s.setblocking(False) + await self.loop.sock_connect(s, ('127.0.0.1', port)) + + try: + await factory_called_future + finally: + srv.close() + await srv.wait_closed() + + self.loop.run_until_complete(test()) + class Test_UV_Context(_ContextBaseTests, tb.UVTestCase): diff --git a/uvloop/handles/streamserver.pxd b/uvloop/handles/streamserver.pxd index b2ab1887..e2093316 100644 --- a/uvloop/handles/streamserver.pxd +++ b/uvloop/handles/streamserver.pxd @@ -7,6 +7,7 @@ cdef class UVStreamServer(UVSocketHandle): object protocol_factory bint opened Server _server + object listen_context # All "inline" methods are final diff --git a/uvloop/handles/streamserver.pyx b/uvloop/handles/streamserver.pyx index 7b2258dd..418ed561 100644 --- a/uvloop/handles/streamserver.pyx +++ b/uvloop/handles/streamserver.pyx @@ -8,6 +8,7 @@ cdef class UVStreamServer(UVSocketHandle): self.ssl_handshake_timeout = None self.ssl_shutdown_timeout = None self.protocol_factory = None + self.listen_context = None cdef inline _init(self, Loop loop, object protocol_factory, Server server, @@ -53,6 +54,9 @@ cdef class UVStreamServer(UVSocketHandle): if self.opened != 1: raise RuntimeError('unopened TCPServer') + if PY37: + self.listen_context = Context_CopyCurrent() + err = uv.uv_listen( self._handle, self.backlog, __uv_streamserver_on_listen) @@ -140,6 +144,11 @@ cdef void __uv_streamserver_on_listen(uv.uv_stream_t* handle, return try: + if PY37: + Context_Enter(stream.listen_context) stream._on_listen() except BaseException as exc: stream._error(exc, False) + finally: + if PY37: + Context_Exit(stream.listen_context)