Skip to content

Commit

Permalink
Merge pull request #109 from HyperionGray/endpoint_info
Browse files Browse the repository at this point in the history
Add metadata for WebSocket endpoints (closes #108)
  • Loading branch information
mehaase committed May 15, 2019
2 parents 218124b + a6f6975 commit 7a562ca
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 43 deletions.
14 changes: 13 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ Requests
.. autoattribute:: headers
.. autoattribute:: proposed_subprotocols
.. autoattribute:: subprotocol
.. autoattribute:: url
.. autoattribute:: local
.. autoattribute:: remote
.. automethod:: accept

Connections
Expand All @@ -44,6 +45,8 @@ Connections
.. autoattribute:: closed
.. autoattribute:: is_client
.. autoattribute:: is_server
.. autoattribute:: local
.. autoattribute:: remote
.. autoattribute:: path
.. autoattribute:: subprotocol

Expand Down Expand Up @@ -77,3 +80,12 @@ Connections
:members:

.. autoexception:: ConnectionClosed

Utilities
---------

These are classes that you do not need to instantiate yourself, but you may
get access to instances of these classes through other APIs.

.. autoclass:: trio_websocket._impl.Endpoint
:members:
73 changes: 60 additions & 13 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import trustme
from async_generator import async_generator, yield_

from trio.testing import memory_stream_pair
from trio_websocket import (
connect_websocket,
connect_websocket_url,
Expand All @@ -49,7 +50,7 @@
wrap_client_stream,
wrap_server_stream
)
from trio_websocket._impl import ListenPort
from trio_websocket._impl import Endpoint


HOST = '127.0.0.1'
Expand Down Expand Up @@ -142,24 +143,36 @@ async def aclose(self):
await trio.hazmat.checkpoint()


async def test_listen_port_ipv4():
assert str(ListenPort('10.105.0.2', 80, False)) == 'ws://10.105.0.2:80'
assert str(ListenPort('127.0.0.1', 8000, False)) == 'ws://127.0.0.1:8000'
assert str(ListenPort('0.0.0.0', 443, True)) == 'wss://0.0.0.0:443'
async def test_endpoint_ipv4():
e1 = Endpoint('10.105.0.2', 80, False)
assert e1.url == 'ws://10.105.0.2'
assert str(e1) == 'Endpoint(address="10.105.0.2", port=80, is_ssl=False)'
e2 = Endpoint('127.0.0.1', 8000, False)
assert e2.url == 'ws://127.0.0.1:8000'
assert str(e2) == 'Endpoint(address="127.0.0.1", port=8000, is_ssl=False)'
e3 = Endpoint('0.0.0.0', 443, True)
assert e3.url == 'wss://0.0.0.0'
assert str(e3) == 'Endpoint(address="0.0.0.0", port=443, is_ssl=True)'


async def test_listen_port_ipv6():
assert str(ListenPort('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False)) \
== 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]:80'
assert str(ListenPort('::1', 8000, False)) == 'ws://[::1]:8000'
assert str(ListenPort('::', 443, True)) == 'wss://[::]:443'
e1 = Endpoint('2599:8807:6201:b7:16cf:bb9c:a6d3:51ab', 80, False)
assert e1.url == 'ws://[2599:8807:6201:b7:16cf:bb9c:a6d3:51ab]'
assert str(e1) == 'Endpoint(address="2599:8807:6201:b7:16cf:bb9c:a6d3' \
':51ab", port=80, is_ssl=False)'
e2 = Endpoint('::1', 8000, False)
assert e2.url == 'ws://[::1]:8000'
assert str(e2) == 'Endpoint(address="::1", port=8000, is_ssl=False)'
e3 = Endpoint('::', 443, True)
assert e3.url == 'wss://[::]'
assert str(e3) == 'Endpoint(address="::", port=443, is_ssl=True)'


async def test_server_has_listeners(nursery):
server = await nursery.start(serve_websocket, echo_request_handler, HOST, 0,
None)
assert len(server.listeners) > 0
assert isinstance(server.listeners[0], ListenPort)
assert isinstance(server.listeners[0], Endpoint)


async def test_serve(nursery):
Expand Down Expand Up @@ -192,6 +205,8 @@ async def test_serve_ssl(nursery):
async with open_websocket(HOST, port, RESOURCE, use_ssl=client_context
) as conn:
assert not conn.closed
assert conn.local.is_ssl
assert conn.remote.is_ssl


async def test_serve_handler_nursery(nursery):
Expand Down Expand Up @@ -277,6 +292,35 @@ async def test_client_connect_url(echo_server, nursery):
assert not conn.closed


async def test_connection_has_endpoints(echo_conn):
async with echo_conn:
assert isinstance(echo_conn.local, Endpoint)
assert str(echo_conn.local.address) == HOST
assert echo_conn.local.port > 1024
assert not echo_conn.local.is_ssl

assert isinstance(echo_conn.remote, Endpoint)
assert str(echo_conn.remote.address) == HOST
assert echo_conn.remote.port > 1024
assert not echo_conn.remote.is_ssl


@fail_after(1)
async def test_handshake_has_endpoints(nursery):
async def handler(request):
assert str(request.local.address) == HOST
assert request.local.port == server.port
assert not request.local.is_ssl
assert str(request.remote.address) == HOST
assert not request.remote.is_ssl
conn = await request.accept()

server = await nursery.start(serve_websocket, handler, HOST, 0, None)
async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False
) as client_ws:
pass


async def test_handshake_subprotocol(nursery):
async def handler(request):
assert request.proposed_subprotocols == ('chat', 'file')
Expand All @@ -299,7 +343,6 @@ async def test_client_send_and_receive(echo_conn):
assert received_msg == 'This is a test message.'



async def test_client_send_invalid_type(echo_conn):
async with echo_conn:
with pytest.raises(ValueError):
Expand Down Expand Up @@ -366,14 +409,18 @@ async def test_client_nondefault_close(echo_conn):
assert echo_conn.closed.reason == 'test reason'


async def test_wrap_client_stream(echo_server, nursery):
stream = await trio.open_tcp_stream(HOST, echo_server.port)
async def test_wrap_client_stream(nursery):
listener = MemoryListener()
server = WebSocketServer(echo_request_handler, [listener])
await nursery.start(server.run)
stream = await listener.connect()
conn = await wrap_client_stream(nursery, stream, HOST, RESOURCE)
async with conn:
assert not conn.closed
await conn.send_message('Hello from client!')
msg = await conn.get_message()
assert msg == 'Hello from client!'
assert conn.local.startswith('StapledStream(')
assert conn.closed


Expand Down
141 changes: 112 additions & 29 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,15 +432,14 @@ class WebSocketRequest:
The server may modify the handshake or leave it as is. The server should
call ``accept()`` to finish the handshake and obtain a connection object.
'''
def __init__(self, accept_fn, event):
def __init__(self, connection, event):
'''
Constructor.
:param accept_fn: A function to call that will finish the handshake and
return a ``WebSocketConnection``.
:param WebSocketConnection connection:
:type event: wsproto.events.Request
'''
self._accept_fn = accept_fn
self._connection = connection
self._event = event
self._subprotocol = None

Expand Down Expand Up @@ -490,14 +489,59 @@ def subprotocol(self, value):
'''
self._subprotocol = value

@property
def local(self):
'''
The connection's local endpoint.
:rtype: Endpoint or str
'''
return self._connection.local

@property
def remote(self):
'''
The connection's remote endpoint.
:rtype: Endpoint or str
'''
return self._connection.remote

async def accept(self):
'''
Finish the handshake with the terms contained in this request and
return a connection object.
:rtype: WebSocketConnection
'''
return await self._accept_fn(self)
await self._connection.accept(self._event, self._subprotocol)
return self._connection


def _get_stream_endpoint(stream, *, local):
'''
Construct an endpoint from a stream.
:param trio.Stream stream:
:param bool local: If true, return local endpoint. Otherwise return remote.
:returns: An endpoint instance or ``repr()`` for streams that cannot be
represented as an endpoint.
:rtype: Endpoint or str
'''
if isinstance(stream, trio.SocketStream):
socket = stream.socket
is_ssl = False
elif isinstance(stream, trio.SSLStream):
socket = stream.transport_stream.socket
is_ssl = True
else:
socket = None
if socket:
addr, port, *_ = socket.getsockname() if local else socket.getpeername()
endpoint = Endpoint(addr, port, is_ssl)
else:
endpoint = repr(stream)
return endpoint


class WebSocketConnection(trio.abc.AsyncResource):
Expand Down Expand Up @@ -584,6 +628,24 @@ def is_server(self):
''' (Read-only) Is this a server instance? '''
return not self._wsproto.client

@property
def local(self):
'''
The local endpoint of the connection.
:rtype: Endpoint or str
'''
return _get_stream_endpoint(self._stream, local=True)

@property
def remote(self):
'''
The remote endpoint of the connection.
:rtype: Endpoint or str
'''
return _get_stream_endpoint(self._stream, local=False)

@property
def path(self):
''' (Read-only) The path from the HTTP handshake. '''
Expand All @@ -601,6 +663,24 @@ def subprotocol(self):
'''
return self._subprotocol

async def accept(self, request, subprotocol):
'''
Accept a connection request.
This finishes the server-side handshake with the given proposal
attributes and return the connection instance. Generally you don't need
to call this method directly. It is invoked for you when you accept a
:class:`WebSocketRequest`.
:param wsproto.events.Request request:
:param subprotocol:
:type subprotocol: str or None
'''
self._subprotocol = subprotocol
self._path = request.target
await self._send(AcceptConnection(subprotocol=self._subprotocol))
self._open_handshake.set()

async def aclose(self, code=1000, reason=None):
'''
Close the WebSocket connection.
Expand Down Expand Up @@ -742,21 +822,6 @@ async def _abort_web_socket(self):
# (e.g. self.aclose()) to resume.
self._close_handshake.set()

async def _accept(self, proposal):
'''
Accept a given proposal.
This finishes the server-side handshake with the given proposal
attributes and return the connection instance.
:rtype: WebSocketConnection
'''
self._subprotocol = proposal.subprotocol
self._path = proposal.path
await self._send(AcceptConnection(subprotocol=self._subprotocol))
self._open_handshake.set()
return self

async def _close_stream(self):
''' Close the TCP connection. '''
self._reader_running = False
Expand Down Expand Up @@ -804,7 +869,7 @@ async def _handle_request_event(self, event):
:param event:
'''
proposal = WebSocketRequest(self._accept, event)
proposal = WebSocketRequest(self, event)
self._connection_proposal.set_value(proposal)

async def _handle_accept_connection_event(self, event):
Expand Down Expand Up @@ -977,20 +1042,35 @@ async def _send(self, event):
raise ConnectionClosed(self._close_reason) from None


class ListenPort:
''' Represents a listener on a given address and port. '''
class Endpoint:
''' Represents a connection endpoint. '''
def __init__(self, address, port, is_ssl):
#: IP address :class:`ipaddress.ip_address`
self.address = ip_address(address)
#: TCP port
self.port = port
#: Whether SSL is in use
self.is_ssl = is_ssl

def __str__(self):
''' Return a compact representation, like 127.0.0.1:80 or [::1]:80. '''
@property
def url(self):
''' Return a URL representation of a TCP endpoint, e.g.
``ws://127.0.0.1:80``. '''
scheme = 'wss' if self.is_ssl else 'ws'
if (self.port == 80 and not self.is_ssl) or \
(self.port == 443 and self.is_ssl):
port_str = ''
else:
port_str = ':' + str(self.port)
if self.address.version == 4:
return '{}://{}:{}'.format(scheme, self.address, self.port)
return '{}://{}{}'.format(scheme, self.address, port_str)
else:
return '{}://[{}]:{}'.format(scheme, self.address, self.port)
return '{}://[{}]{}'.format(scheme, self.address, port_str)

def __repr__(self):
''' Return endpoint info as string. '''
return 'Endpoint(address="{}", port={}, is_ssl={})'.format(self.address,
self.port, self.is_ssl)


class WebSocketServer:
Expand Down Expand Up @@ -1062,8 +1142,11 @@ def port(self):
def listeners(self):
'''
Return a list of listener metadata. Each TCP listener is represented as
a ``ListenPort`` instance. Other listener types are represented by their
an ``Endpoint`` instance. Other listener types are represented by their
``repr()``.
:returns: Listeners
:rtype list[Endpoint or str]:
'''
listeners = list()
for listener in self._listeners:
Expand All @@ -1077,7 +1160,7 @@ def listeners(self):
socket = None
if socket:
sockname = socket.getsockname()
listeners.append(ListenPort(sockname[0], sockname[1], is_ssl))
listeners.append(Endpoint(sockname[0], sockname[1], is_ssl))
else:
listeners.append(repr(listener))
return listeners
Expand Down

0 comments on commit 7a562ca

Please sign in to comment.