Skip to content

Commit

Permalink
pythongh-111246: Don't remove stolen Unix socket address (python#111246)
Browse files Browse the repository at this point in the history
We only want to clean up *our* socket, so try to determine if we still
own this address or if something else has replaced it.
  • Loading branch information
CendioOssman committed Oct 30, 2023
1 parent 21e9d8c commit df3b74e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
3 changes: 2 additions & 1 deletion Doc/library/asyncio-eventloop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,8 @@ Creating network servers
.. versionchanged:: 3.13

The Unix socket will automatically be removed from the filesystem
when the server is closed.
when the server is closed, unless the socket has been replaced
after the server has been created.


.. coroutinemethod:: loop.connect_accepted_socket(protocol_factory, \
Expand Down
21 changes: 14 additions & 7 deletions Lib/asyncio/unix_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class _UnixSelectorEventLoop(selector_events.BaseSelectorEventLoop):
def __init__(self, selector=None):
super().__init__(selector)
self._signal_handlers = {}
self._unix_server_sockets = {}

def close(self):
super().close()
Expand Down Expand Up @@ -340,6 +341,14 @@ async def create_unix_server(
raise ValueError(
f'A UNIX Domain Stream Socket was expected, got {sock!r}')

path = sock.getsockname()
# Check for abstract socket. `str` and `bytes` paths are supported.
if path[0] not in (0, '\x00'):
try:
self._unix_server_sockets[sock] = os.stat(path).st_ino
except FileNotFoundError:
pass

sock.setblocking(False)
server = base_events.Server(self, [sock], protocol_factory,
ssl, backlog, ssl_handshake_timeout,
Expand Down Expand Up @@ -462,21 +471,19 @@ def cb(fut):

def _stop_serving(self, sock):
# Is this a unix socket that needs cleanup?
if sock.family == socket.AF_UNIX:
if sock in self._unix_server_sockets:
path = sock.getsockname()
if path == '':
path = None
# Check for abstract socket. `str` and `bytes` paths are supported.
elif path[0] in (0, '\x00'):
path = None
else:
path = None

super()._stop_serving(sock)

if path is not None:
prev_ino = self._unix_server_sockets[sock]
del self._unix_server_sockets[sock]
try:
os.unlink(path)
if os.stat(path).st_ino == prev_ino:
os.unlink(path)
except FileNotFoundError:
pass
except OSError as err:
Expand Down
15 changes: 15 additions & 0 deletions Lib/test/test_asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,21 @@ async def serve(*args):

srv.close()

@socket_helper.skip_unless_bind_unix_socket
async def test_unix_server_cleanup_replaced(self):
with test_utils.unix_socket_path() as addr:
async def serve(*args):
pass

srv = await asyncio.start_unix_server(serve, addr)

os.unlink(addr)
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.bind(addr)

srv.close()
self.assertTrue(os.path.exists(addr))


@unittest.skipUnless(hasattr(asyncio, 'ProactorEventLoop'), 'Windows only')
class ProactorStartServerTests(BaseStartServer, unittest.TestCase):
Expand Down

0 comments on commit df3b74e

Please sign in to comment.