Skip to content

Commit

Permalink
Refactored the Listener API
Browse files Browse the repository at this point in the history
This replaces accept() with serve() which in turn obsoletes anyio.serve_listeners() which did not make it into any release.
Among other things, this lets us do the TLS handshake in the newly spawned handler task.

Fixes #125.
  • Loading branch information
agronholm committed Aug 4, 2020
1 parent 88c6292 commit 58618bd
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 102 deletions.
5 changes: 3 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ Streams and stream wrappers

.. autofunction:: anyio.create_memory_object_stream

.. autofunction:: anyio.serve_listeners

.. autoclass:: anyio.streams.buffered.BufferedByteReceiveStream
:members:

Expand All @@ -81,6 +79,9 @@ Streams and stream wrappers
.. autoclass:: anyio.streams.memory.MemoryObjectSendStream
:members:

.. autoclass:: anyio.streams.stapled.MultiListener
:members:

.. autoclass:: anyio.streams.stapled.StapledByteStream
:members:

Expand Down
24 changes: 12 additions & 12 deletions docs/networking.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ If you need to establish a TLS session over TCP, you can use :func:`~anyio.conne
a convenience (instead of wrapping the stream with :meth:`anyio.streams.tls.TLSStream.wrap` after
a successful connection).

To receive incoming TCP connections, you first create TCP listeners with
:func:`anyio.create_tcp_listeners` and then pass them to :func:`~anyio.serve_listeners`::
To receive incoming TCP connections, you first create a TCP listener with
:func:`anyio.create_tcp_listener` and call :meth:`~anyio.abc.streamsListener.serve` on it::

from anyio import create_tcp_listeners, serve_listeners, run
from anyio import create_tcp_listeners, run


async def handle(client):
async with client:
name = await client.receive(1024)
await client.send_all(b'Hello, %s\n' % name)
await client.send(b'Hello, %s\n' % name)


async def main():
listeners = create_tcp_listeners(local_port=1234)
await serve_listeners(handle, listeners)
listener = create_tcp_listener(local_port=1234)
await listener.serve(handle)

run(main)

Expand All @@ -76,26 +76,26 @@ This is what the client from the TCP example looks like when converted to use UN

async def main():
async with await connect_unix('/tmp/mysock') as client:
await client.send_all(b'Client\n')
response = await client.receive_until(b'\n', 1024)
await client.send(b'Client\n')
response = await client.receive(1024)
print(response)

run(main)

And the listener::

from anyio import create_unix_listener, serve_listeners, run
from anyio import create_unix_listener, run


async def handle(client):
async with client:
name = await client.receive_until(b'\n', 1024)
await client.send_all(b'Hello, %s\n' % name)
name = await client.receive(1024)
await client.send(b'Hello, %s\n' % name)


async def main():
listener = await create_unix_listener('/tmp/mysock')
await serve_listeners(handle, [listener])
await listener.serve(handle)

run(main)

Expand Down
30 changes: 6 additions & 24 deletions src/anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from ._utils import convert_ipv6_sockaddr
from .abc import (
Lock, Condition, Event, Semaphore, CapacityLimiter, CancelScope, TaskGroup, IPAddressType,
SocketStream, UDPSocket, ConnectedUDPSocket, IPSockAddrType, Listener, SocketListener, Process,
SocketStream, UDPSocket, ConnectedUDPSocket, IPSockAddrType, SocketListener, Process,
AsyncResource)
from .fileio import AsyncFile
from .streams.stapled import MultiListener
from .streams.tls import TLSStream
from .streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

Expand Down Expand Up @@ -518,11 +519,11 @@ async def connect_unix(path: Union[str, 'os.PathLike']) -> SocketStream:
return await _get_asynclib().connect_unix(path)


async def create_tcp_listeners(
async def create_tcp_listener(
*, local_host: Optional[IPAddressType] = None, local_port: int = 0,
family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC, backlog: int = 65536,
reuse_port: bool = False
) -> List[SocketListener[IPSockAddrType]]:
) -> MultiListener[SocketStream[IPSockAddrType]]:
"""
Create a TCP socket listener.
Expand All @@ -544,7 +545,7 @@ async def create_tcp_listeners(
gai_res = await getaddrinfo(local_host, local_port, family=family, # type: ignore[arg-type]
type=socket.SOCK_STREAM,
flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG)
listeners = []
listeners: List[SocketListener[IPSockAddrType]] = []
try:
# The set() is here to work around a glibc bug:
# https://sourceware.org/bugzilla/show_bug.cgi?id=14969
Expand Down Expand Up @@ -575,7 +576,7 @@ async def create_tcp_listeners(

raise

return listeners
return MultiListener(listeners)


async def create_unix_listener(
Expand Down Expand Up @@ -609,25 +610,6 @@ async def create_unix_listener(
raise


async def serve_listeners(handler: Callable[[Any], Any], listeners: typing.Iterable[Listener], *,
handler_task_group: Optional[TaskGroup] = None) -> None:
async def serve_listener(listener: Listener) -> typing.NoReturn:
async with listener:
while True:
# TODO: handle the same OSErrors as trio does
stream = await listener.accept()
try:
await listener_task_group.spawn(handler, stream)
except BaseException:
await stream.aclose()
raise

async with create_task_group() as tg:
listener_task_group = handler_task_group or tg
for listener in listeners:
await tg.spawn(serve_listener, listener)


async def create_udp_socket(
family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC, *,
local_host: Optional[IPAddressType] = None, local_port: int = 0, reuse_port: bool = False
Expand Down
8 changes: 8 additions & 0 deletions src/anyio/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ def convert_ipv6_sockaddr(sockaddr):
return sockaddr[:2]
else:
return sockaddr


class NullAsyncContextManager:
async def __aenter__(self):
pass

async def __aexit__(self, exc_type, exc_val, exc_tb):
pass
28 changes: 26 additions & 2 deletions src/anyio/abc/sockets.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from abc import abstractmethod
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily
from typing import TypeVar, Tuple, Union, Generic
from typing import (
TypeVar, Tuple, Union, Generic, Callable, Any, Optional, AsyncContextManager)

from .streams import UnreliableObjectStream, ByteStream, Listener
from .tasks import TaskGroup
from .streams import UnreliableObjectStream, ByteStream, Listener, T_Stream

IPAddressType = Union[str, IPv4Address, IPv6Address]
IPSockAddrType = Tuple[str, int]
Expand Down Expand Up @@ -64,6 +66,28 @@ class SocketListener(Generic[T_SockAddr], Listener[SocketStream[T_SockAddr]],
_SocketMixin[T_SockAddr]):
"""Listens to incoming socket connections."""

@abstractmethod
async def accept(self) -> SocketStream[T_SockAddr]:
"""Accept an incoming connection."""

async def serve(self, handler: Callable[[T_Stream], Any],
task_group: Optional[TaskGroup] = None) -> None:
from .. import create_task_group
from .._utils import NullAsyncContextManager

context_manager: AsyncContextManager
if task_group is None:
task_group = context_manager = create_task_group()
else:
# Can be replaced with AsyncExitStack once on py3.7+
context_manager = NullAsyncContextManager()

# There is a mypy bug here
async with context_manager: # type: ignore[attr-defined]
while True:
stream = await self.accept()
await task_group.spawn(handler, stream)


class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketMixin[IPSockAddrType]):
"""Represents an unconnected UDP socket."""
Expand Down
34 changes: 14 additions & 20 deletions src/anyio/abc/streams.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from abc import abstractmethod
from typing import Generic, TypeVar, Union
from typing import Generic, TypeVar, Union, Callable, Any, Optional

from .resource import AsyncResource
from ..exceptions import EndOfStream, ClosedResourceError
from .tasks import TaskGroup
from ..exceptions import EndOfStream

T_Item = TypeVar('T_Item')
T_Stream = TypeVar('T_Stream', bound='AnyByteStream', covariant=True)
T_Stream = TypeVar('T_Stream')


class UnreliableObjectReceiveStream(Generic[T_Item], AsyncResource):
Expand Down Expand Up @@ -167,22 +168,15 @@ async def send_eof(self) -> None:


class Listener(Generic[T_Stream], AsyncResource):
"""
An interface for objects that let you accept incoming connections.
Asynchronously iterating over this object will yield streams matching the type parameter
given for this interface.
"""

def __aiter__(self):
return self

async def __anext__(self) -> T_Stream:
try:
return await self.accept()
except ClosedResourceError:
raise StopAsyncIteration
"""An interface for objects that let you accept incoming connections."""

@abstractmethod
async def accept(self) -> T_Stream:
"""Accept an incoming connection."""
async def serve(self, handler: Callable[[T_Stream], Any],
task_group: Optional[TaskGroup] = None) -> None:
"""
Accept incoming connections as they come in and spawn tasks to handle them.
:param handler: a callable that will be used to handle each accepted connection
:param task_group: the task group that will be used to spawn tasks for handling each
accepted connection (if omitted, an ad-hoc task group will be created)
"""
44 changes: 42 additions & 2 deletions src/anyio/streams/stapled.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from dataclasses import dataclass
from typing import TypeVar, Generic
from typing import TypeVar, Generic, Callable, Any, Optional, Sequence

from ..abc import (
ByteReceiveStream, ByteStream, ByteSendStream, ObjectStream, ObjectReceiveStream,
ObjectSendStream)
ObjectSendStream, Listener, TaskGroup)

T_Item = TypeVar('T_Item')
T_Stream = TypeVar('T_Stream')


@dataclass
Expand Down Expand Up @@ -58,3 +59,42 @@ async def send_eof(self) -> None:
async def aclose(self) -> None:
await self.send_stream.aclose()
await self.receive_stream.aclose()


@dataclass
class MultiListener(Generic[T_Stream], Listener[T_Stream]):
"""
Combines multiple listeners into one, serving connections from all of them at once.
Any MultiListeners in the given collection of listeners will have their listeners moved into
this one.
:param listeners: listeners to serve
:type listeners: Sequence[Listener[T_Stream]]
"""

listeners: Sequence[Listener[T_Stream]]

def __post_init__(self):
listeners = []
for listener in self.listeners:
if isinstance(listener, MultiListener):
listeners.extend(listener.listeners)
del listener.listeners[:]
else:
listeners.append(listener)

self.listeners = listeners

async def serve(self, handler: Callable[[T_Stream], Any],
task_group: Optional[TaskGroup] = None) -> None:
from .. import create_task_group

# There is a mypy bug here
async with create_task_group() as tg: # type: ignore[attr-defined]
for listener in self.listeners:
await tg.spawn(listener.serve, handler, task_group)

async def aclose(self) -> None:
for listener in self.listeners:
await listener.aclose()
16 changes: 10 additions & 6 deletions src/anyio/streams/tls.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import ssl
import sys
from dataclasses import dataclass
from typing import Optional, Callable, Tuple, overload, List, Dict, Union, TypeVar
from typing import Optional, Callable, Tuple, overload, List, Dict, Union, TypeVar, Any

from ..abc import ByteStream, AnyByteStream, Listener
from ..abc import ByteStream, AnyByteStream, Listener, TaskGroup
from ..exceptions import EndOfStream, BrokenResourceError

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -197,10 +197,14 @@ def __post_init__(self):
if self.context is None:
self.context = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)

async def accept(self) -> TLSStream:
transport_stream = await self.listener.accept()
return await TLSStream.wrap(transport_stream, ssl_context=self.context,
standard_compatible=self.standard_compatible)
async def serve(self, handler: Callable[[TLSStream], Any],
task_group: Optional[TaskGroup] = None) -> None:
async def handler_wrapper(stream: AnyByteStream):
wrapped_stream = await TLSStream.wrap(stream, ssl_context=self.context,
standard_compatible=self.standard_compatible)
await handler(wrapped_stream)

await self.listener.serve(handler_wrapper, task_group)

async def aclose(self) -> None:
await self.listener.aclose()

0 comments on commit 58618bd

Please sign in to comment.