diff --git a/CHANGES/8163.bugfix.rst b/CHANGES/8163.bugfix.rst new file mode 100644 index 00000000000..8bfb10260c6 --- /dev/null +++ b/CHANGES/8163.bugfix.rst @@ -0,0 +1,5 @@ +Improved the DNS resolution performance on cache hit +-- by :user:`bdraco`. + +This is achieved by avoiding an :mod:`asyncio` task creation +in this case. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 3b9841dd094..f95ebe84c66 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -825,6 +825,7 @@ def clear_dns_cache( async def _resolve_host( self, host: str, port: int, traces: Optional[List["Trace"]] = None ) -> List[Dict[str, Any]]: + """Resolve host and return list of addresses.""" if is_ip_address(host): return [ { @@ -852,8 +853,7 @@ async def _resolve_host( return res key = (host, port) - - if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)): + if key in self._cached_hosts and not self._cached_hosts.expired(key): # get result early, before any await (#4014) result = self._cached_hosts.next_addrs(key) @@ -862,6 +862,39 @@ async def _resolve_host( await trace.send_dns_cache_hit(host) return result + # + # If multiple connectors are resolving the same host, we wait + # for the first one to resolve and then use the result for all of them. + # We use a throttle event to ensure that we only resolve the host once + # and then use the result for all the waiters. + # + # In this case we need to create a task to ensure that we can shield + # the task from cancellation as cancelling this lookup should not cancel + # the underlying lookup or else the cancel event will get broadcast to + # all the waiters across all connections. + # + resolved_host_task = asyncio.create_task( + self._resolve_host_with_throttle(key, host, port, traces) + ) + try: + return await asyncio.shield(resolved_host_task) + except asyncio.CancelledError: + + def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: + with suppress(Exception, asyncio.CancelledError): + fut.result() + + resolved_host_task.add_done_callback(drop_exception) + raise + + async def _resolve_host_with_throttle( + self, + key: Tuple[str, int], + host: str, + port: int, + traces: Optional[List["Trace"]], + ) -> List[Dict[str, Any]]: + """Resolve host with a dns events throttle.""" if key in self._throttle_dns_events: # get event early, before any await (#4014) event = self._throttle_dns_events[key] @@ -1163,22 +1196,11 @@ async def _create_direct_connection( host = host.rstrip(".") + "." port = req.port assert port is not None - host_resolved = asyncio.ensure_future( - self._resolve_host(host, port, traces=traces), loop=self._loop - ) try: # Cancelling this lookup should not cancel the underlying lookup # or else the cancel event will get broadcast to all the waiters # across all connections. - hosts = await asyncio.shield(host_resolved) - except asyncio.CancelledError: - - def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: - with suppress(Exception, asyncio.CancelledError): - fut.result() - - host_resolved.add_done_callback(drop_exception) - raise + hosts = await self._resolve_host(host, port, traces=traces) except OSError as exc: if exc.errno is None and isinstance(exc, asyncio.TimeoutError): raise diff --git a/tests/test_connector.py b/tests/test_connector.py index 142abab3c15..02e48bc108b 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -767,6 +767,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None: loop.create_task(conn._resolve_host("localhost", 8080)) loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) + await asyncio.sleep(0) m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0) @@ -778,6 +779,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non r1 = loop.create_task(conn._resolve_host("localhost", 8080)) r2 = loop.create_task(conn._resolve_host("localhost", 8080)) await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) + await asyncio.sleep(0) assert r1.exception() == e assert r2.exception() == e @@ -792,6 +796,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close( loop.create_task(conn._resolve_host("localhost", 8080)) f = loop.create_task(conn._resolve_host("localhost", 8080)) + await asyncio.sleep(0) await asyncio.sleep(0) await conn.close() @@ -956,6 +961,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) - loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) loop.create_task(conn._resolve_host("localhost", 8080, traces=traces)) await asyncio.sleep(0) + await asyncio.sleep(0) on_dns_cache_hit.assert_called_once_with( session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost") )