Skip to content

Commit

Permalink
Added basic socket support
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Sep 11, 2018
1 parent aa88a92 commit 2cae087
Show file tree
Hide file tree
Showing 6 changed files with 450 additions and 111 deletions.
8 changes: 7 additions & 1 deletion hyperio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
import threading
import typing # noqa: F401
from _socket import AF_INET, SOCK_STREAM
from contextlib import contextmanager
from importlib import import_module
from pathlib import Path
Expand All @@ -9,7 +10,7 @@

from .interfaces import ( # noqa: F401
IPAddressType, StreamingSocket, CancelScope, DatagramSocket, Lock,
Condition, Event, Semaphore, Queue, TaskGroup)
Condition, Event, Semaphore, Queue, TaskGroup, Socket)

T_Retval = TypeVar('T_Retval', covariant=True)
_local = threading.local()
Expand Down Expand Up @@ -118,6 +119,11 @@ def run_async_from_thread(func: Callable[..., T_Retval], *args) -> T_Retval:
# Networking
#

def create_socket(family: int = AF_INET, type: int = SOCK_STREAM, proto: int = 0,
fileno=None) -> Socket:
return _get_asynclib().create_socket(family, type, proto, fileno)


def connect_tcp(
address: IPAddressType, port: int, *,
bind: Union[IPAddressType, Iterable[IPAddressType], None] = None) -> \
Expand Down
303 changes: 196 additions & 107 deletions hyperio/backends/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import asyncio
import errno
import inspect
import socket
import ssl
import sys
from contextlib import closing, suppress
from pathlib import Path
from socket import SocketType
from ssl import SSLContext
from contextlib import suppress
from ipaddress import ip_address
from threading import Thread
from typing import Callable, Set, Optional, List, Union, Iterable, AsyncIterable, Dict # noqa:F401
from typing import Callable, Set, Optional, List, Union, Dict # noqa: F401

from async_generator import async_generator, yield_, asynccontextmanager

from .. import (
interfaces, IPAddressType, StreamingSocket, DatagramSocket, claim_current_thread, _local,
T_Retval)
from ..exceptions import MultiError, DelimiterNotFound, CancelledError
from .. import interfaces, claim_current_thread, _local, T_Retval
from ..exceptions import MultiError, CancelledError

try:
from asyncio import run as native_run, create_task, get_running_loop, current_task
Expand Down Expand Up @@ -326,133 +322,226 @@ def run_async_from_thread(func: Callable[..., T_Retval], *args) -> T_Retval:
class AsyncIOSocket:
__slots__ = '_loop', '_sock'

def __init__(self, sock: SocketType) -> None:
def __init__(self, sock: socket.SocketType) -> None:
self._loop = get_running_loop()
self._sock = sock

def __getattr__(self, item):
return getattr(self._sock, item)

async def __aenter__(self) -> 'AsyncIOSocket':
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
self._sock.close()

async def read(self, size: Optional[int] = None) -> bytes:
return await self._loop.sock_recv(self._sock, size)

async def send(self, data: bytes) -> None:
await self._loop.sock_sendall(self._sock, data)

def __enter__(self):
return self

class AsyncIOStreamingSocket(AsyncIOSocket, interfaces.StreamingSocket):
__slots__ = ()
def __exit__(self, exc_type, exc_val, exc_tb):
self._sock.close()

async def read_exactly(self, nbytes: int) -> bytes:
buf = b''
while nbytes > 0:
data = await self._loop.sock_recv(self._sock, nbytes)
buf += data
nbytes -= len(data)
async def wait_readable(self) -> None:
_check_cancelled()
event = asyncio.Event()
self._loop.add_reader(self._sock.fileno(), event.set)
await event.wait()

return buf
async def wait_writable(self) -> None:
_check_cancelled()
event = asyncio.Event()
self._loop.add_writer(self._sock.fileno(), event.set)
await event.wait()

async def read_until(self, delimiter: bytes, max_size: int) -> bytes:
index = 0
delimiter_size = len(delimiter)
buf = b''
while len(buf) < max_size:
data = await self._loop.sock_recv(self._sock, max_size - len(buf))
buf += data
if buf.find(delimiter, index):
return buf
else:
index += len(data) - delimiter_size + 1
async def accept(self):
_check_cancelled()
try:
raw_socket, address = self._sock.accept()
except BlockingIOError:
await self.wait_readable()
raw_socket, address = self._sock.accept()

raise DelimiterNotFound(
'Maximum number of bytes ({}) read while searching for delimiter ({})'.format(
max_size, delimiter))
raw_socket.setblocking(False)
return AsyncIOSocket(raw_socket), address

async def start_tls(self, ssl_context: SSLContext) -> None:
def ready_callback():
async def bind(self, address: Union[tuple, str, bytes]) -> None:
# For IP address/port combinations, call bind() directly
_check_cancelled()
if isinstance(address, tuple) and len(address) == 2:
try:
sslsock.do_handshake()
except ssl.SSLWantReadError:
print('Want SSL read')
except ssl.SSLWantWriteError:
print('Want SSL write')
except BaseException as exc:
future.set_exception(exc)
ip_address(address[0])
except ValueError:
pass
else:
future.set_result(None)
self._sock.bind(address)
return

# In all other cases, do this in a worker thread to avoid blocking the event loop thread
await run_in_thread(self._sock.bind, address)

sslsock = ssl_context.wrap_socket(self._sock)
future = self._loop.create_future()
self._loop.add_reader(self._sock.fileno(), ready_callback)
self._loop.add_writer(self._sock.fileno(), ready_callback)
async def connect(self, address: Union[tuple, str, bytes]) -> None:
_check_cancelled()
try:
await future
finally:
self._loop.remove_reader(self._sock.fileno())
self._loop.remove_writer(self._sock.fileno())
self._sock.connect(address)
except BlockingIOError:
await self.wait_writable()

self._sock = sslsock
error = self._sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
if error:
raise OSError(error, errno.errorcode[error])

async def recv(self, size: int, *, flags: int = 0) -> bytes:
_check_cancelled()
try:
return self._sock.recv(size, flags)
except BlockingIOError:
await self.wait_readable()
return self._sock.recv(size)

class AsyncIODatagramSocket(AsyncIOSocket, interfaces.DatagramSocket):
__slots__ = ()
async def recv_into(self, buffer, nbytes: int, *, flags: int = 0) -> int:
_check_cancelled()
try:
return self._sock.recv_into(buffer, nbytes, flags)
except BlockingIOError:
await self.wait_readable()
return self._sock.recv_into(buffer, nbytes, flags)

async def send(self, data: bytes, address: Optional[IPAddressType] = None) -> None:
if address:
self._sock.connect(str(address))
async def send(self, data: bytes, *, flags: int = 0) -> int:
_check_cancelled()
try:
return self._sock.send(data, flags)
except BlockingIOError:
await self.wait_writable()
return self._sock.send(data, flags)

async def sendall(self, data: bytes, *, flags: int = 0) -> None:
to_send = len(data)
while to_send > 0:
_check_cancelled()
try:
sent = self._sock.send(data, flags)
except BlockingIOError:
await self.wait_writable()
else:
to_send -= sent

await self._loop.sock_sendall(data)

# class AsyncIOStreamingSocket(AsyncIOSocket, interfaces.StreamingSocket):
# __slots__ = ()
#
# async def read_exactly(self, nbytes: int) -> bytes:
# buf = b''
# while nbytes > 0:
# data = await self._loop.sock_recv(self._sock, nbytes)
# buf += data
# nbytes -= len(data)
#
# return buf
#
# async def read_until(self, delimiter: bytes, max_size: int) -> bytes:
# index = 0
# delimiter_size = len(delimiter)
# buf = b''
# while len(buf) < max_size:
# data = await self._loop.sock_recv(self._sock, max_size - len(buf))
# buf += data
# if buf.find(delimiter, index):
# return buf
# else:
# index += len(data) - delimiter_size + 1
#
# raise DelimiterNotFound(
# 'Maximum number of bytes ({}) read while searching for delimiter ({})'.format(
# max_size, delimiter))
#
# async def start_tls(self, ssl_context: SSLContext) -> None:
# def ready_callback():
# try:
# sslsock.do_handshake()
# except ssl.SSLWantReadError:
# print('Want SSL read')
# except ssl.SSLWantWriteError:
# print('Want SSL write')
# except BaseException as exc:
# future.set_exception(exc)
# else:
# future.set_result(None)
#
# sslsock = ssl_context.wrap_socket(self._sock)
# future = self._loop.create_future()
# self._loop.add_reader(self._sock.fileno(), ready_callback)
# self._loop.add_writer(self._sock.fileno(), ready_callback)
# try:
# await future
# finally:
# self._loop.remove_reader(self._sock.fileno())
# self._loop.remove_writer(self._sock.fileno())
#
# self._sock = sslsock

async def connect_tcp(
address: IPAddressType, port: int, *,
bind: Union[IPAddressType, Iterable[IPAddressType], None] = None) -> StreamingSocket:
_check_cancelled()
sock = socket.socket()
sock.setblocking(False)
loop = get_running_loop()
await loop.sock_connect(sock, (address, port))
return AsyncIOStreamingSocket(sock)

# class AsyncIODatagramSocket(AsyncIOSocket, interfaces.DatagramSocket):
# __slots__ = ()
#
# async def send(self, data: bytes, address: Optional[IPAddressType] = None) -> None:
# if address:
# self._sock.connect(str(address))
#
# await self._loop.sock_sendall(data)

async def connect_unix(path: Union[str, Path]) -> StreamingSocket:
_check_cancelled()
sock = socket.socket(socket.AF_UNIX)
sock.setblocking(False)
loop = get_running_loop()
await loop.sock_connect(sock, str(path))
return AsyncIOStreamingSocket(sock)

def create_socket(family: int, type: int, proto: int, fileno) -> interfaces.Socket:
raw_socket = socket.socket(family, type, proto, fileno)
raw_socket.setblocking(False)
return AsyncIOSocket(raw_socket)

@async_generator
async def serve_tcp(
port: int, *, bind: Union[IPAddressType, Iterable[IPAddressType]] = '*',
ssl_context: Optional[SSLContext] = None) -> AsyncIterable[StreamingSocket]:
_check_cancelled()
with closing(socket.socket()) as server_sock:
server_sock.setblocking(False)
server_sock.bind((str(bind), port))
server_sock.listen(5)
while True:
raw_sock, address = await _local.loop.sock_accept(server_sock)
stream = AsyncIOStreamingSocket(raw_sock)
del raw_sock, address
await yield_(stream)


async def create_udp_socket(
*, bind: Union[IPAddressType, Iterable[IPAddressType], None] = None,
target: Optional[IPAddressType] = None) -> DatagramSocket:
_check_cancelled()
sock = socket.socket()
sock.setblocking(False)
if target is not None:
sock.connect(target)

return AsyncIODatagramSocket(sock)
# async def connect_tcp(
# address: IPAddressType, port: int, *,
# bind: Union[IPAddressType, Iterable[IPAddressType], None] = None) -> StreamingSocket:
# _check_cancelled()
# sock = socket.socket()
# sock.setblocking(False)
# loop = get_running_loop()
# await loop.sock_connect(sock, (address, port))
# return AsyncIOStreamingSocket(sock)
#
#
# async def connect_unix(path: Union[str, Path]) -> StreamingSocket:
# _check_cancelled()
# sock = socket.socket(socket.AF_UNIX)
# sock.setblocking(False)
# loop = get_running_loop()
# await loop.sock_connect(sock, str(path))
# return AsyncIOStreamingSocket(sock)
#
#
# @async_generator
# async def serve_tcp(
# port: int, *, bind: Union[IPAddressType, Iterable[IPAddressType]] = '*',
# ssl_context: Optional[SSLContext] = None) -> AsyncIterable[StreamingSocket]:
# _check_cancelled()
# with closing(socket.socket()) as server_sock:
# server_sock.setblocking(False)
# server_sock.bind((str(bind), port))
# server_sock.listen(5)
# while True:
# raw_sock, address = await _local.loop.sock_accept(server_sock)
# stream = AsyncIOStreamingSocket(raw_sock)
# del raw_sock, address
# await yield_(stream)
#
#
# async def create_udp_socket(
# *, bind: Union[IPAddressType, Iterable[IPAddressType], None] = None,
# target: Optional[IPAddressType] = None) -> DatagramSocket:
# _check_cancelled()
# sock = socket.socket()
# sock.setblocking(False)
# if target is not None:
# sock.connect(target)
#
# return AsyncIODatagramSocket(sock)


#
Expand Down Expand Up @@ -513,8 +602,8 @@ def put(self, item):
return super().put(item)


interfaces.CancelScope.register(AsyncIOCancelScope)
interfaces.TaskGroup.register(AsyncIOTaskGroup)
interfaces.Socket.register(AsyncIOSocket)
interfaces.Lock.register(Lock)
interfaces.Condition.register(Condition)
interfaces.Event.register(Event)
Expand Down
Loading

0 comments on commit 2cae087

Please sign in to comment.