diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 8036cce1..16232159 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -189,6 +189,91 @@ def test_socket_sync_remove_and_immediately_close(self): self.assertEqual(sock.fileno(), -1) self.loop.run_until_complete(asyncio.sleep(0.01, loop=self.loop)) + def test_sock_cancel_add_reader_race(self): + srv_sock_conn = None + + async def server(): + nonlocal srv_sock_conn + sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_server.setblocking(False) + with sock_server: + sock_server.bind(('127.0.0.1', 0)) + sock_server.listen() + fut = asyncio.ensure_future( + client(sock_server.getsockname()), loop=self.loop) + srv_sock_conn, _ = await self.loop.sock_accept(sock_server) + srv_sock_conn.setsockopt( + socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + with srv_sock_conn: + await fut + + async def client(addr): + sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_client.setblocking(False) + with sock_client: + await self.loop.sock_connect(sock_client, addr) + _, pending_read_futs = await asyncio.wait( + [self.loop.sock_recv(sock_client, 1)], + timeout=1, loop=self.loop) + + async def send_server_data(): + # Wait a little bit to let reader future cancel and + # schedule the removal of the reader callback. Right after + # "rfut.cancel()" we will call "loop.sock_recv()", which + # will add a reader. This will make a race between + # remove- and add-reader. + await asyncio.sleep(0.1, loop=self.loop) + await self.loop.sock_sendall(srv_sock_conn, b'1') + self.loop.create_task(send_server_data()) + + for rfut in pending_read_futs: + rfut.cancel() + + data = await self.loop.sock_recv(sock_client, 1) + + self.assertEqual(data, b'1') + + self.loop.run_until_complete(server()) + + def test_sock_send_before_cancel(self): + srv_sock_conn = None + + async def server(): + nonlocal srv_sock_conn + sock_server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_server.setblocking(False) + with sock_server: + sock_server.bind(('127.0.0.1', 0)) + sock_server.listen() + fut = asyncio.ensure_future( + client(sock_server.getsockname()), loop=self.loop) + srv_sock_conn, _ = await self.loop.sock_accept(sock_server) + with srv_sock_conn: + await fut + + async def client(addr): + await asyncio.sleep(0.01, loop=self.loop) + sock_client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock_client.setblocking(False) + with sock_client: + await self.loop.sock_connect(sock_client, addr) + _, pending_read_futs = await asyncio.wait( + [self.loop.sock_recv(sock_client, 1)], + timeout=1, loop=self.loop) + + # server can send the data in a random time, even before + # the previous result future has cancelled. + await self.loop.sock_sendall(srv_sock_conn, b'1') + + for rfut in pending_read_futs: + rfut.cancel() + + data = await self.loop.sock_recv(sock_client, 1) + + self.assertEqual(data, b'1') + + self.loop.run_until_complete(server()) + class TestUVSockets(_TestSockets, tb.UVTestCase): diff --git a/uvloop/handles/poll.pxd b/uvloop/handles/poll.pxd index ac6100ea..d07030b5 100644 --- a/uvloop/handles/poll.pxd +++ b/uvloop/handles/poll.pxd @@ -13,6 +13,8 @@ cdef class UVPoll(UVHandle): cdef int is_active(self) cdef is_reading(self) + cdef is_writing(self) + cdef start_reading(self, Handle callback) cdef start_writing(self, Handle callback) cdef stop_reading(self) diff --git a/uvloop/handles/poll.pyx b/uvloop/handles/poll.pyx index 8640eced..941809a6 100644 --- a/uvloop/handles/poll.pyx +++ b/uvloop/handles/poll.pyx @@ -87,6 +87,9 @@ cdef class UVPoll(UVHandle): cdef is_reading(self): return self._is_alive() and self.reading_handle is not None + cdef is_writing(self): + return self._is_alive() and self.writing_handle is not None + cdef start_reading(self, Handle callback): cdef: int mask = 0 diff --git a/uvloop/loop.pxd b/uvloop/loop.pxd index 6640abea..3086b235 100644 --- a/uvloop/loop.pxd +++ b/uvloop/loop.pxd @@ -177,12 +177,12 @@ cdef class Loop: cdef _track_process(self, UVProcess proc) cdef _untrack_process(self, UVProcess proc) - cdef _new_reader_future(self, sock) - cdef _new_writer_future(self, sock) cdef _add_reader(self, fd, Handle handle) + cdef _has_reader(self, fd) cdef _remove_reader(self, fd) cdef _add_writer(self, fd, Handle handle) + cdef _has_writer(self, fd) cdef _remove_writer(self, fd) cdef _sock_recv(self, fut, sock, n) diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index 67938b83..5a82c105 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -742,6 +742,20 @@ cdef class Loop: return result + cdef _has_reader(self, fileobj): + cdef: + UVPoll poll + + self._check_closed() + fd = self._fileobj_to_fd(fileobj) + + try: + poll = (self._polls[fd]) + except KeyError: + return False + + return poll.is_reading() + cdef _add_writer(self, fileobj, Handle handle): cdef: UVPoll poll @@ -791,6 +805,20 @@ cdef class Loop: return result + cdef _has_writer(self, fileobj): + cdef: + UVPoll poll + + self._check_closed() + fd = self._fileobj_to_fd(fileobj) + + try: + poll = (self._polls[fd]) + except KeyError: + return False + + return poll.is_writing() + cdef _getaddrinfo(self, object host, object port, int family, int type, int proto, int flags, @@ -845,35 +873,17 @@ cdef class Loop: nr.query(addr, flags) return fut - cdef _new_reader_future(self, sock): - def _on_cancel(fut): - # Check if the future was cancelled and if the socket - # is still open, i.e. - # - # loop.remove_reader(sock) - # sock.close() - # fut.cancel() - # - # wasn't called by the user. - if fut.cancelled() and sock.fileno() != -1: - self._remove_reader(sock) - - fut = self._new_future() - fut.add_done_callback(_on_cancel) - return fut - - cdef _new_writer_future(self, sock): - def _on_cancel(fut): - if fut.cancelled() and sock.fileno() != -1: - self._remove_writer(sock) - - fut = self._new_future() - fut.add_done_callback(_on_cancel) - return fut - cdef _sock_recv(self, fut, sock, n): - cdef: - Handle handle + if UVLOOP_DEBUG: + if fut.cancelled(): + # Shouldn't happen with _SyncSocketReaderFuture. + raise RuntimeError( + f'_sock_recv is called on a cancelled Future') + + if not self._has_reader(sock): + raise RuntimeError( + f'socket {sock!r} does not have a reader ' + f'in the _sock_recv callback') try: data = sock.recv(n) @@ -889,8 +899,16 @@ cdef class Loop: self._remove_reader(sock) cdef _sock_recv_into(self, fut, sock, buf): - cdef: - Handle handle + if UVLOOP_DEBUG: + if fut.cancelled(): + # Shouldn't happen with _SyncSocketReaderFuture. + raise RuntimeError( + f'_sock_recv_into is called on a cancelled Future') + + if not self._has_reader(sock): + raise RuntimeError( + f'socket {sock!r} does not have a reader ' + f'in the _sock_recv_into callback') try: data = sock.recv_into(buf) @@ -910,6 +928,17 @@ cdef class Loop: Handle handle int n + if UVLOOP_DEBUG: + if fut.cancelled(): + # Shouldn't happen with _SyncSocketReaderFuture. + raise RuntimeError( + f'_sock_sendall is called on a cancelled Future') + + if not self._has_writer(sock): + raise RuntimeError( + f'socket {sock!r} does not have a writer ' + f'in the _sock_sendall callback') + try: n = sock.send(data) except (BlockingIOError, InterruptedError): @@ -940,9 +969,6 @@ cdef class Loop: self._add_writer(sock, handle) cdef _sock_accept(self, fut, sock): - cdef: - Handle handle - try: conn, address = sock.accept() conn.setblocking(False) @@ -2217,7 +2243,7 @@ cdef class Loop: if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - fut = self._new_reader_future(sock) + fut = _SyncSocketReaderFuture(sock, self) handle = new_MethodHandle3( self, "Loop._sock_recv", @@ -2243,7 +2269,7 @@ cdef class Loop: if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - fut = self._new_reader_future(sock) + fut = _SyncSocketReaderFuture(sock, self) handle = new_MethodHandle3( self, "Loop._sock_recv_into", @@ -2294,7 +2320,7 @@ cdef class Loop: data = memoryview(data) data = data[n:] - fut = self._new_writer_future(sock) + fut = _SyncSocketWriterFuture(sock, self) handle = new_MethodHandle3( self, "Loop._sock_sendall", @@ -2324,7 +2350,7 @@ cdef class Loop: if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") - fut = self._new_reader_future(sock) + fut = _SyncSocketReaderFuture(sock, self) handle = new_MethodHandle2( self, "Loop._sock_accept", @@ -2908,6 +2934,36 @@ cdef inline void __loop_free_buffer(Loop loop): loop._recv_buffer_in_use = 0 +class _SyncSocketReaderFuture(aio_Future): + + def __init__(self, sock, loop): + aio_Future.__init__(self, loop=loop) + self.__sock = sock + self.__loop = loop + + def cancel(self): + if self.__sock is not None and self.__sock.fileno() != -1: + self.__loop.remove_reader(self.__sock) + self.__sock = None + + aio_Future.cancel(self) + + +class _SyncSocketWriterFuture(aio_Future): + + def __init__(self, sock, loop): + aio_Future.__init__(self, loop=loop) + self.__sock = sock + self.__loop = loop + + def cancel(self): + if self.__sock is not None and self.__sock.fileno() != -1: + self.__loop.remove_writer(self.__sock) + self.__sock = None + + aio_Future.cancel(self) + + include "cbhandles.pyx" include "pseudosock.pyx"