Skip to content

Commit

Permalink
Fix loop.getaddrinfo() and tests (#495)
Browse files Browse the repository at this point in the history
* ai_canonname always follows the flag AI_CANONNAME in static resolving (#494)
* AddressFamily and SocketKind can be enums
* Also fixed failing test
  • Loading branch information
fantix committed Sep 13, 2022
1 parent d6a2b59 commit 598b16f
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 15 deletions.
66 changes: 58 additions & 8 deletions tests/test_dns.py
Expand Up @@ -5,12 +5,32 @@
from uvloop import _testbase as tb


def patched_getaddrinfo(*args, **kwargs):
# corrected socket.getaddrinfo() behavior: ai_canonname always follows the
# flag AI_CANONNAME, even if `host` is an IP
rv = []
result = socket.getaddrinfo(*args, **kwargs)
for af, sk, proto, canon_name, addr in result:
if kwargs.get('flags', 0) & socket.AI_CANONNAME:
if not canon_name:
canon_name = args[0]
if not isinstance(canon_name, str):
canon_name = canon_name.decode('ascii')
elif canon_name:
canon_name = ''
rv.append((af, sk, proto, canon_name, addr))
return rv


class BaseTestDNS:

def _test_getaddrinfo(self, *args, **kwargs):
def _test_getaddrinfo(self, *args, _patch=False, **kwargs):
err = None
try:
a1 = socket.getaddrinfo(*args, **kwargs)
if _patch:
a1 = patched_getaddrinfo(*args, **kwargs)
else:
a1 = socket.getaddrinfo(*args, **kwargs)
except socket.gaierror as ex:
err = ex

Expand Down Expand Up @@ -100,20 +120,36 @@ def test_getaddrinfo_11(self):
self._test_getaddrinfo(b'example.com', '80', type=socket.SOCK_STREAM)

def test_getaddrinfo_12(self):
# musl always returns ai_canonname but we don't
patch = self.implementation != 'asyncio'

self._test_getaddrinfo('127.0.0.1', '80')
self._test_getaddrinfo('127.0.0.1', '80', type=socket.SOCK_STREAM)
self._test_getaddrinfo('127.0.0.1', '80', type=socket.SOCK_STREAM,
_patch=patch)

def test_getaddrinfo_13(self):
# musl always returns ai_canonname but we don't
patch = self.implementation != 'asyncio'

self._test_getaddrinfo(b'127.0.0.1', b'80')
self._test_getaddrinfo(b'127.0.0.1', b'80', type=socket.SOCK_STREAM)
self._test_getaddrinfo(b'127.0.0.1', b'80', type=socket.SOCK_STREAM,
_patch=patch)

def test_getaddrinfo_14(self):
# musl always returns ai_canonname but we don't
patch = self.implementation != 'asyncio'

self._test_getaddrinfo(b'127.0.0.1', b'http')
self._test_getaddrinfo(b'127.0.0.1', b'http', type=socket.SOCK_STREAM)
self._test_getaddrinfo(b'127.0.0.1', b'http', type=socket.SOCK_STREAM,
_patch=patch)

def test_getaddrinfo_15(self):
# musl always returns ai_canonname but we don't
patch = self.implementation != 'asyncio'

self._test_getaddrinfo('127.0.0.1', 'http')
self._test_getaddrinfo('127.0.0.1', 'http', type=socket.SOCK_STREAM)
self._test_getaddrinfo('127.0.0.1', 'http', type=socket.SOCK_STREAM,
_patch=patch)

def test_getaddrinfo_16(self):
self._test_getaddrinfo('localhost', 'http')
Expand All @@ -128,12 +164,26 @@ def test_getaddrinfo_18(self):
self._test_getaddrinfo('localhost', b'http', type=socket.SOCK_STREAM)

def test_getaddrinfo_19(self):
# musl always returns ai_canonname while macOS never return for IPs,
# but we strictly follow the docs to use the AI_CANONNAME flag
patch = self.implementation != 'asyncio'

self._test_getaddrinfo('::1', 80)
self._test_getaddrinfo('::1', 80, type=socket.SOCK_STREAM)
self._test_getaddrinfo('::1', 80, type=socket.SOCK_STREAM,
_patch=patch)
self._test_getaddrinfo('::1', 80, type=socket.SOCK_STREAM,
flags=socket.AI_CANONNAME, _patch=patch)

def test_getaddrinfo_20(self):
# musl always returns ai_canonname while macOS never return for IPs,
# but we strictly follow the docs to use the AI_CANONNAME flag
patch = self.implementation != 'asyncio'

self._test_getaddrinfo('127.0.0.1', 80)
self._test_getaddrinfo('127.0.0.1', 80, type=socket.SOCK_STREAM)
self._test_getaddrinfo('127.0.0.1', 80, type=socket.SOCK_STREAM,
_patch=patch)
self._test_getaddrinfo('127.0.0.1', 80, type=socket.SOCK_STREAM,
flags=socket.AI_CANONNAME, _patch=patch)

######

Expand Down
2 changes: 1 addition & 1 deletion tests/test_tcp.py
Expand Up @@ -222,7 +222,7 @@ def test_create_server_4(self):

with self.assertRaisesRegex(OSError,
r"error while attempting.*\('127.*: "
r"address already in use"):
r"address( already)? in use"):

self.loop.run_until_complete(
self.loop.create_server(object, *addr))
Expand Down
27 changes: 24 additions & 3 deletions uvloop/dns.pyx
Expand Up @@ -245,7 +245,21 @@ cdef __static_getaddrinfo_pyaddr(object host, object port,
except Exception:
return

return af, type, proto, '', pyaddr
if flags & socket_AI_CANONNAME:
if isinstance(host, str):
canon_name = host
else:
canon_name = host.decode('ascii')
else:
canon_name = ''

return (
_intenum_converter(af, socket_AddressFamily),
_intenum_converter(type, socket_SocketKind),
proto,
canon_name,
pyaddr,
)


@cython.freelist(DEFAULT_FREELIST_SIZE)
Expand Down Expand Up @@ -276,8 +290,8 @@ cdef class AddrInfo:
while ptr != NULL:
if ptr.ai_addr.sa_family in (uv.AF_INET, uv.AF_INET6):
result.append((
ptr.ai_family,
ptr.ai_socktype,
_intenum_converter(ptr.ai_family, socket_AddressFamily),
_intenum_converter(ptr.ai_socktype, socket_SocketKind),
ptr.ai_protocol,
('' if ptr.ai_canonname is NULL else
(<bytes>ptr.ai_canonname).decode()),
Expand Down Expand Up @@ -370,6 +384,13 @@ cdef class NameInfoRequest(UVRequest):
self.callback(convert_error(err))


cdef _intenum_converter(value, enum_klass):
try:
return enum_klass(value)
except ValueError:
return value


cdef void __on_addrinfo_resolved(uv.uv_getaddrinfo_t *resolver,
int status, system.addrinfo *res) with gil:

Expand Down
1 change: 1 addition & 0 deletions uvloop/includes/stdlib.pxi
Expand Up @@ -72,6 +72,7 @@ cdef int has_SO_REUSEPORT = hasattr(socket, 'SO_REUSEPORT')
cdef int SO_REUSEPORT = getattr(socket, 'SO_REUSEPORT', 0)
cdef int SO_BROADCAST = getattr(socket, 'SO_BROADCAST')
cdef int SOCK_NONBLOCK = getattr(socket, 'SOCK_NONBLOCK', -1)
cdef int socket_AI_CANONNAME = getattr(socket, 'AI_CANONNAME')

cdef socket_gaierror = socket.gaierror
cdef socket_error = socket.error
Expand Down
4 changes: 1 addition & 3 deletions uvloop/loop.pyx
Expand Up @@ -1527,9 +1527,7 @@ cdef class Loop:
addr = __static_getaddrinfo_pyaddr(host, port, family,
type, proto, flags)
if addr is not None:
fut = self._new_future()
fut.set_result([addr])
return await fut
return [addr]

return await self._getaddrinfo(
host, port, family, type, proto, flags, 1)
Expand Down

0 comments on commit 598b16f

Please sign in to comment.