Skip to content

Commit

Permalink
Implement happy eyeballs (RFC 8305) (#7954)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco committed Jan 5, 2024
1 parent 2393146 commit c4ec3f1
Show file tree
Hide file tree
Showing 13 changed files with 477 additions and 42 deletions.
1 change: 1 addition & 0 deletions CHANGES/7954.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Implement happy eyeballs
67 changes: 51 additions & 16 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import logging
import random
import socket
import sys
import traceback
import warnings
Expand Down Expand Up @@ -31,6 +32,8 @@
cast,
)

import aiohappyeyeballs

from . import hdrs, helpers
from .abc import AbstractResolver
from .client_exceptions import (
Expand Down Expand Up @@ -730,6 +733,10 @@ class TCPConnector(BaseConnector):
limit_per_host - Number of simultaneous connections to one host.
enable_cleanup_closed - Enables clean-up closed ssl transports.
Disabled by default.
happy_eyeballs_delay - This is the “Connection Attempt Delay”
as defined in RFC 8305. To disable
the happy eyeballs algorithm, set to None.
interleave - “First Address Family Count” as defined in RFC 8305
loop - Optional event loop.
"""

Expand All @@ -748,6 +755,8 @@ def __init__(
limit_per_host: int = 0,
enable_cleanup_closed: bool = False,
timeout_ceil_threshold: float = 5,
happy_eyeballs_delay: Optional[float] = 0.25,
interleave: Optional[int] = None,
) -> None:
super().__init__(
keepalive_timeout=keepalive_timeout,
Expand All @@ -772,7 +781,9 @@ def __init__(
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {}
self._family = family
self._local_addr = local_addr
self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
self._happy_eyeballs_delay = happy_eyeballs_delay
self._interleave = interleave

def _close_immediately(self) -> List["asyncio.Future[None]"]:
for ev in self._throttle_dns_events.values():
Expand Down Expand Up @@ -956,6 +967,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
addr_infos: List[aiohappyeyeballs.AddrInfoType],
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
Expand All @@ -965,7 +977,14 @@ async def _wrap_create_connection(
async with ceil_timeout(
timeout.sock_connect, ceil_threshold=timeout.ceil_threshold
):
return await self._loop.create_connection(*args, **kwargs)
sock = await aiohappyeyeballs.start_connection(
addr_infos=addr_infos,
local_addr_infos=self._local_addr_infos,
happy_eyeballs_delay=self._happy_eyeballs_delay,
interleave=self._interleave,
loop=self._loop,
)
return await self._loop.create_connection(*args, **kwargs, sock=sock)
except cert_errors as exc:
raise ClientConnectorCertificateError(req.connection_key, exc) from exc
except ssl_errors as exc:
Expand Down Expand Up @@ -1076,6 +1095,27 @@ async def _start_tls_connection(

return tls_transport, tls_proto

def _convert_hosts_to_addr_infos(
self, hosts: List[Dict[str, Any]]
) -> List[aiohappyeyeballs.AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.
The list of hosts is the result of a DNS lookup. The list of
addr_infos is the result of a call to `socket.getaddrinfo()`.
"""
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
for hinfo in hosts:
host = hinfo["host"]
is_ipv6 = ":" in host
family = socket.AF_INET6 if is_ipv6 else socket.AF_INET
if self._family and self._family != family:
continue
addr = (host, hinfo["port"], 0, 0) if is_ipv6 else (host, hinfo["port"])
addr_infos.append(
(family, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", addr)
)
return addr_infos

async def _create_direct_connection(
self,
req: ClientRequest,
Expand Down Expand Up @@ -1120,36 +1160,27 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
raise ClientConnectorError(req.connection_key, exc) from exc

last_exc: Optional[Exception] = None

for hinfo in hosts:
host = hinfo["host"]
port = hinfo["port"]

addr_infos = self._convert_hosts_to_addr_infos(hosts)
while addr_infos:
# Strip trailing dots, certificates contain FQDN without dots.
# See https://github.com/aio-libs/aiohttp/issues/3636
server_hostname = (
(req.server_hostname or hinfo["hostname"]).rstrip(".")
if sslcontext
else None
(req.server_hostname or host).rstrip(".") if sslcontext else None
)

try:
transp, proto = await self._wrap_create_connection(
self._factory,
host,
port,
timeout=timeout,
ssl=sslcontext,
family=hinfo["family"],
proto=hinfo["proto"],
flags=hinfo["flags"],
addr_infos=addr_infos,
server_hostname=server_hostname,
local_addr=self._local_addr,
req=req,
client_error=client_error,
)
except ClientConnectorError as exc:
last_exc = exc
aiohappyeyeballs.pop_addr_infos_interleave(addr_infos, self._interleave)
continue

if req.is_ssl() and fingerprint:
Expand All @@ -1160,6 +1191,10 @@ def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
if not self._cleanup_closed_disabled:
self._cleanup_closed_transports.append(transp)
last_exc = exc
# Remove the bad peer from the list of addr_infos
sock: socket.socket = transp.get_extra_info("socket")
bad_peer = sock.getpeername()
aiohappyeyeballs.remove_addr_infos(addr_infos, bad_peer)
continue

return transp, proto
Expand Down
21 changes: 20 additions & 1 deletion docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1055,7 +1055,8 @@ is controlled by *force_close* constructor's parameter).
family=0, ssl_context=None, local_addr=None, \
resolver=None, keepalive_timeout=sentinel, \
force_close=False, limit=100, limit_per_host=0, \
enable_cleanup_closed=False, loop=None)
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
happy_eyeballs_delay=0.25, interleave=None, loop=None)

Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.

Expand Down Expand Up @@ -1158,6 +1159,24 @@ is controlled by *force_close* constructor's parameter).
If this parameter is set to True, aiohttp additionally aborts underlining
transport after 2 seconds. It is off by default.

:param float happy_eyeballs_delay: The amount of time in seconds to wait for a
connection attempt to complete, before starting the next attempt in parallel.
This is the “Connection Attempt Delay” as defined in RFC 8305. To disable
Happy Eyeballs, set this to ``None``. The default value recommended by the
RFC is 0.25 (250 milliseconds).

.. versionadded:: 3.10

:param int interleave: controls address reordering when a host name resolves
to multiple IP addresses. If ``0`` or unspecified, no reordering is done, and
addresses are tried in the order returned by the resolver. If a positive
integer is specified, the addresses are interleaved by address family, and
the given integer is interpreted as “First Address Family Count” as defined
in RFC 8305. The default is ``0`` if happy_eyeballs_delay is not specified, and
``1`` if it is.

.. versionadded:: 3.10

.. attribute:: family

*TCP* socket family e.g. :data:`socket.AF_INET` or
Expand Down
2 changes: 2 additions & 0 deletions requirements/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.0
# via -r requirements/runtime-deps.in
aiosignal==1.3.1
# via -r requirements/runtime-deps.in
async-timeout==4.0.3 ; python_version < "3.11"
Expand Down
2 changes: 2 additions & 0 deletions requirements/constraints.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.0
# via -r requirements/runtime-deps.in
aiohttp-theme==0.1.6
# via -r requirements/doc.in
aioredis==2.0.1
Expand Down
2 changes: 2 additions & 0 deletions requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.0
# via -r requirements/runtime-deps.in
aiohttp-theme==0.1.6
# via -r requirements/doc.in
aioredis==2.0.1
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime-deps.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`

aiodns >= 1.1; sys_platform=="linux" or sys_platform=="darwin"
aiohappyeyeballs >= 2.3.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 5.0 ; python_version < "3.11"
Brotli; platform_python_implementation == 'CPython'
Expand Down
2 changes: 2 additions & 0 deletions requirements/runtime-deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.0
# via -r requirements/runtime-deps.in
aiosignal==1.3.1
# via -r requirements/runtime-deps.in
async-timeout==4.0.3 ; python_version < "3.11"
Expand Down
2 changes: 2 additions & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
aiodns==3.1.1 ; sys_platform == "linux" or sys_platform == "darwin"
# via -r requirements/runtime-deps.in
aiohappyeyeballs==2.3.0
# via -r requirements/runtime-deps.in
aiosignal==1.3.1
# via -r requirements/runtime-deps.in
async-timeout==4.0.3 ; python_version < "3.11"
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ zip_safe = False
include_package_data = True

install_requires =
aiohappyeyeballs >= 2.3.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 5.0 ; python_version < "3.11"
frozenlist >= 1.1.1
Expand Down
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,13 @@ def netrc_contents(
monkeypatch.setenv("NETRC", str(netrc_file_path))

return netrc_file_path


@pytest.fixture
def start_connection():
with mock.patch(
"aiohttp.connector.aiohappyeyeballs.start_connection",
autospec=True,
spec_set=True,
) as start_connection_mock:
yield start_connection_mock

0 comments on commit c4ec3f1

Please sign in to comment.