Skip to content

Commit

Permalink
Added properties and methods to get TLS metadata from SocketStream
Browse files Browse the repository at this point in the history
Fixes nedbat#6.
  • Loading branch information
agronholm committed Oct 21, 2018
1 parent 6735505 commit 304b3c5
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 6 deletions.
56 changes: 54 additions & 2 deletions anyio/_networking.py
Expand Up @@ -3,13 +3,13 @@
import ssl
from abc import ABCMeta, abstractmethod
from ipaddress import ip_address
from typing import Union, Tuple, Any, Optional, Callable, AsyncIterable
from typing import Union, Tuple, Any, Optional, Callable, AsyncIterable, Dict, List

from async_generator import async_generator, yield_

from anyio import abc
from anyio.abc import IPAddressType, BufferType
from anyio.exceptions import DelimiterNotFound, IncompleteRead
from anyio.exceptions import DelimiterNotFound, IncompleteRead, TLSRequired


class BaseSocket(metaclass=ABCMeta):
Expand Down Expand Up @@ -244,10 +244,62 @@ async def receive_delimited_chunks(self, delimiter: bytes,
async def send_all(self, data: BufferType) -> None:
return await self._socket.sendall(data)

#
# TLS methods
#

def _call_sslsocket_method(self, name: str, *args):
try:
method = getattr(self._socket, name)
except AttributeError:
raise TLSRequired from None

return method(*args)

async def start_tls(self, context: Optional[ssl.SSLContext] = None) -> None:
ssl_context = context or self._ssl_context or ssl.create_default_context()
await self._socket.start_tls(ssl_context, self._server_hostname)

def getpeercert(self, binary_form: bool = False) -> Union[Dict[str, Union[str, tuple]],
bytes, None]:
return self._call_sslsocket_method('getpeercert', binary_form)

@property
def alpn_protocol(self) -> Optional[str]:
return self._call_sslsocket_method('selected_alpn_protocol')

def get_channel_binding(self, cb_type: str = 'tls-unique') -> bytes:
return self._call_sslsocket_method('get_channel_binding', cb_type)

@property
def tls_version(self) -> Optional[str]:
try:
return self._call_sslsocket_method('version')
except TLSRequired:
return None

@property
def cipher(self) -> Tuple[str, str, int]:
return self._call_sslsocket_method('cipher')

@property
def shared_ciphers(self) -> List[Tuple[str, str, int]]:
return self._call_sslsocket_method('shared_ciphers')

@property
def server_hostname(self) -> str:
try:
return self._socket.server_hostname
except AttributeError:
raise TLSRequired from None

@property
def server_side(self) -> bool:
try:
return self._socket.server_side
except AttributeError:
raise TLSRequired from None


class SocketStreamServer(abc.SocketStreamServer):
__slots__ = '_socket', '_ssl_context'
Expand Down
94 changes: 93 additions & 1 deletion anyio/abc.py
Expand Up @@ -2,7 +2,7 @@
from io import SEEK_SET
from ipaddress import IPv4Address, IPv6Address
from ssl import SSLContext
from typing import Callable, TypeVar, Optional, Tuple, Union, AsyncIterable
from typing import Callable, TypeVar, Optional, Tuple, Union, AsyncIterable, Dict, List

T_Retval = TypeVar('T_Retval')
IPAddressType = Union[str, IPv4Address, IPv6Address]
Expand Down Expand Up @@ -311,6 +311,98 @@ async def start_tls(self, context: Optional[SSLContext] = None) -> None:
:param context: an explicit SSL context to use for the handshake
"""

@abstractmethod
def getpeercert(self, binary_form: bool = False) -> Union[Dict[str, Union[str, tuple]],
bytes, None]:
"""
Get the certificate for the peer on the other end of the connection.
See :func:`ssl.SSLSocket.getpeercert` for more information.
:param binary_form: ``False`` to return the certificate as a dict, ``True`` to return it
as bytes
:return: the peer's certificate, or ``None`` if there is not certificate for the peer
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""

@property
@abstractmethod
def alpn_protocol(self) -> Optional[str]:
"""
Return the ALPN protocol selected during the TLS handshake.
:return: The selected ALPN protocol, or ``None`` if no ALPN protocol was selected
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""

@abstractmethod
def get_channel_binding(self, cb_type: str = 'tls-unique') -> bytes:
"""
Get the channel binding data for the current connection.
See :func:`ssl.SSLSocket.get_channel_binding` for more information.
:param cb_type: type of the channel binding to get
:return: the channel binding data
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""

@property
@abstractmethod
def tls_version(self) -> Optional[str]:
"""
Return the TLS version negotiated during the TLS handshake.
See :func:`ssl.SSLSocket.version` for more information.
:return: the TLS version string (e.g. "TLSv1.3"), or ``None`` if the underlying socket is
not using TLS
"""

@property
@abstractmethod
def cipher(self) -> Tuple[str, str, int]:
"""
Return the cipher selected in the TLS handshake.
See :func:`ssl.SSLSocket.cipher` for more information.
:return: a 3-tuple of (cipher name, TLS version which defined it, number of bits)
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""

@property
@abstractmethod
def shared_ciphers(self) -> List[Tuple[str, str, int]]:
"""
Return the list of ciphers supported by both parties in the TLS handshake.
See :func:`ssl.SSLSocket.shared_ciphers` for more information.
:return: a list of 3-tuples (cipher name, TLS version which defined it, number of bits)
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""

@property
@abstractmethod
def server_hostname(self) -> Optional[str]:
"""
Return the server host name.
:return: the server host name, or ``None`` if this is the server side of the connection
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""

@property
@abstractmethod
def server_side(self) -> bool:
"""
Check if this is the server or client side of the connection.
:return: ``True`` if this is the server side, ``False`` if this is the client side
:raises anyio.exceptions.TLSRequired: if a TLS handshake has not been done
"""


class SocketStreamServer(metaclass=ABCMeta):
async def __aenter__(self):
Expand Down
4 changes: 4 additions & 0 deletions anyio/exceptions.py
Expand Up @@ -60,3 +60,7 @@ def __init__(self, data: bytes) -> None:

class ClosedResourceError(Exception):
"""Raised when a resource is closed by another task."""


class TLSRequired(Exception):
"""Raised when a TLS related stream method is called before the TLS handshake has been done."""
38 changes: 35 additions & 3 deletions tests/test_networking.py
Expand Up @@ -146,36 +146,68 @@ async def server():


class TestTLSStream:
@pytest.fixture(scope='class')
@pytest.fixture
def server_context(self):
server_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
server_context.load_cert_chain(certfile=str(Path(__file__).with_name('cert.pem')),
keyfile=str(Path(__file__).with_name('key.pem')))
return server_context

@pytest.fixture(scope='class')
@pytest.fixture
def client_context(self):
client_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
client_context.load_verify_locations(cafile=str(Path(__file__).with_name('cert.pem')))
return client_context

@pytest.mark.anyio
async def test_connect_tcp_tls(self, server_context, client_context):
async def test_handshake(self, server_context, client_context):
async def server():
nonlocal server_binding
async with await stream_server.accept() as stream:
assert stream.server_side
assert stream.server_hostname is None
assert stream.tls_version.startswith('TLSv')
assert stream.cipher in stream.shared_ciphers
server_binding = stream.get_channel_binding()

command = await stream.receive_some(100)
await stream.send_all(command[::-1])

server_binding = None
async with create_task_group() as tg:
async with await create_tcp_server(
interface='localhost', ssl_context=server_context) as stream_server:
await tg.spawn(server)
async with await connect_tcp('localhost', stream_server.port,
tls=client_context) as client:
assert not client.server_side
assert client.server_hostname == 'localhost'
assert client.tls_version.startswith('TLSv')
assert client.cipher in client.shared_ciphers
client_binding = client.get_channel_binding()

await client.send_all(b'blah')
response = await client.receive_some(100)

assert response == b'halb'
assert client_binding == server_binding
assert isinstance(client_binding, bytes)

@pytest.mark.anyio
async def test_alpn_negotiation(self, server_context, client_context):
async def server():
async with await stream_server.accept() as stream:
assert stream.alpn_protocol == 'dummy2'

client_context.set_alpn_protocols(['dummy1', 'dummy2'])
server_context.set_alpn_protocols(['dummy2', 'dummy3'])
async with await create_tcp_server(
interface='localhost', ssl_context=server_context) as stream_server:
async with create_task_group() as tg:
await tg.spawn(server)
async with await connect_tcp('localhost', stream_server.port,
tls=client_context) as client:
assert client.alpn_protocol == 'dummy2'


class TestUDPSocket:
Expand Down

0 comments on commit 304b3c5

Please sign in to comment.