Skip to content

Commit

Permalink
[PR #8163/006fbe03 backport][3.9] Avoid creating a task to do DNS res…
Browse files Browse the repository at this point in the history
…olution if there is no throttle (#8172)

Co-authored-by: J. Nick Koston <nick@koston.org>
Fixes #123'). -->
  • Loading branch information
patchback[bot] committed Feb 20, 2024
1 parent 87e0697 commit e74a4a0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 14 deletions.
5 changes: 5 additions & 0 deletions CHANGES/8163.bugfix.rst
@@ -0,0 +1,5 @@
Improved the DNS resolution performance on cache hit
-- by :user:`bdraco`.

This is achieved by avoiding an :mod:`asyncio` task creation
in this case.
50 changes: 36 additions & 14 deletions aiohttp/connector.py
Expand Up @@ -825,6 +825,7 @@ def clear_dns_cache(
async def _resolve_host(
self, host: str, port: int, traces: Optional[List["Trace"]] = None
) -> List[Dict[str, Any]]:
"""Resolve host and return list of addresses."""
if is_ip_address(host):
return [
{
Expand Down Expand Up @@ -852,8 +853,7 @@ async def _resolve_host(
return res

key = (host, port)

if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)):
if key in self._cached_hosts and not self._cached_hosts.expired(key):
# get result early, before any await (#4014)
result = self._cached_hosts.next_addrs(key)

Expand All @@ -862,6 +862,39 @@ async def _resolve_host(
await trace.send_dns_cache_hit(host)
return result

#
# If multiple connectors are resolving the same host, we wait
# for the first one to resolve and then use the result for all of them.
# We use a throttle event to ensure that we only resolve the host once
# and then use the result for all the waiters.
#
# In this case we need to create a task to ensure that we can shield
# the task from cancellation as cancelling this lookup should not cancel
# the underlying lookup or else the cancel event will get broadcast to
# all the waiters across all connections.
#
resolved_host_task = asyncio.create_task(
self._resolve_host_with_throttle(key, host, port, traces)
)
try:
return await asyncio.shield(resolved_host_task)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

resolved_host_task.add_done_callback(drop_exception)
raise

async def _resolve_host_with_throttle(
self,
key: Tuple[str, int],
host: str,
port: int,
traces: Optional[List["Trace"]],
) -> List[Dict[str, Any]]:
"""Resolve host with a dns events throttle."""
if key in self._throttle_dns_events:
# get event early, before any await (#4014)
event = self._throttle_dns_events[key]
Expand Down Expand Up @@ -1163,22 +1196,11 @@ async def _create_direct_connection(
host = host.rstrip(".") + "."
port = req.port
assert port is not None
host_resolved = asyncio.ensure_future(
self._resolve_host(host, port, traces=traces), loop=self._loop
)
try:
# Cancelling this lookup should not cancel the underlying lookup
# or else the cancel event will get broadcast to all the waiters
# across all connections.
hosts = await asyncio.shield(host_resolved)
except asyncio.CancelledError:

def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
with suppress(Exception, asyncio.CancelledError):
fut.result()

host_resolved.add_done_callback(drop_exception)
raise
hosts = await self._resolve_host(host, port, traces=traces)
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
raise
Expand Down
6 changes: 6 additions & 0 deletions tests/test_connector.py
Expand Up @@ -767,6 +767,7 @@ async def test_tcp_connector_dns_throttle_requests(loop, dns_response) -> None:
loop.create_task(conn._resolve_host("localhost", 8080))
loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
m_resolver().resolve.assert_called_once_with("localhost", 8080, family=0)


Expand All @@ -778,6 +779,9 @@ async def test_tcp_connector_dns_throttle_requests_exception_spread(loop) -> Non
r1 = loop.create_task(conn._resolve_host("localhost", 8080))
r2 = loop.create_task(conn._resolve_host("localhost", 8080))
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
assert r1.exception() == e
assert r2.exception() == e

Expand All @@ -792,6 +796,7 @@ async def test_tcp_connector_dns_throttle_requests_cancelled_when_close(
loop.create_task(conn._resolve_host("localhost", 8080))
f = loop.create_task(conn._resolve_host("localhost", 8080))

await asyncio.sleep(0)
await asyncio.sleep(0)
await conn.close()

Expand Down Expand Up @@ -956,6 +961,7 @@ async def test_tcp_connector_dns_tracing_throttle_requests(loop, dns_response) -
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
loop.create_task(conn._resolve_host("localhost", 8080, traces=traces))
await asyncio.sleep(0)
await asyncio.sleep(0)
on_dns_cache_hit.assert_called_once_with(
session, trace_config_ctx, aiohttp.TraceDnsCacheHitParams("localhost")
)
Expand Down

0 comments on commit e74a4a0

Please sign in to comment.