Skip to content

Commit

Permalink
Merge 8e13b26 into e971e7b
Browse files Browse the repository at this point in the history
  • Loading branch information
mehaase committed Oct 17, 2018
2 parents e971e7b + 8e13b26 commit afd2853
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 12 deletions.
73 changes: 73 additions & 0 deletions README.md
Expand Up @@ -78,6 +78,79 @@ trio.run(main)
A longer example is in `examples/server.py`. **See the note above about using
SSL with the example client.**

## Heartbeat recipe

If you wish to keep a connection open for long periods of time but do not need
to send messages frequently, then a heartbeat holds the connection open and also
detects when the connection drops unexpectedly. The following recipe
demonstrates how to implement a connection heartbeat using WebSocket's ping/pong
feature.

```python
async def heartbeat(ws, timeout, interval):
'''
Send periodic pings on WebSocket ``ws``.
Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises
``TooSlowError`` if the timeout is exceeded. If a pong is received, then
wait ``interval`` seconds before sending the next ping.
This function runs until cancelled.
:param ws: A WebSocket to send heartbeat pings on.
:param float timeout: Timeout in seconds.
:param float interval: Interval between receiving pong and sending next
ping, in seconds.
:raises: ``ConnectionClosed`` if ``ws`` is closed.
:raises: ``TooSlowError`` if the timeout expires.
:returns: This function runs until cancelled.
'''
while True:
with trio.fail_after(timeout):
await ws.ping()
await trio.sleep(interval)

async def main():
async with open_websocket_url('ws://localhost/foo') as ws:
async with trio.open_nursery() as nursery:
nursery.start_soon(heartbeat, ws, 5, 1)
# Your application code goes here:
pass

trio.run(main)
```

Note that the `ping()` method waits until it receives a pong frame, so it
ensures that the remote endpoint is still responsive. If the connection is
dropped unexpectedly or takes too long to respond, then `heartbeat()` will raise
an exception that will cancel the nursery. You may wish to implement additional
logic to automatically reconnect.

A heartbeat feature can be enabled in the example client with the
``--heartbeat`` flag.

**Note that the WebSocket RFC does not require a WebSocket to send a pong for each
ping:**

> If an endpoint receives a Ping frame and has not yet sent Pong frame(s) in
> response to previous Ping frame(s), the endpoint MAY elect to send a Pong
> frame for only the most recently processed Ping frame.
Therefore, if you have multiple pings in flight at the same time, you may not
get an equal number of pongs in response. The simplest strategy for dealing with
this is to only have one ping in flight at a time, as seen in the example above.
As an alternative, you can send a `bytes` payload with each ping. The server
will return the payload with the pong:

```python
await ws.ping(b'my payload')
pong == await ws.wait_pong()
assert pong == b'my payload'
```

You may want to embed a nonce or counter in the payload in order to correlate
pong events to the pings you have sent.

## Unit Tests

Unit tests are written in the pytest style. You must install the development
Expand Down
32 changes: 30 additions & 2 deletions examples/client.py
Expand Up @@ -32,6 +32,8 @@ def commands():
def parse_args():
''' Parse command line arguments. '''
parser = argparse.ArgumentParser(description='Example trio-websocket client')
parser.add_argument('--heartbeat', action='store_true',
help='Create a heartbeat task')
parser.add_argument('url', help='WebSocket URL to connect to')
return parser.parse_args()

Expand All @@ -53,17 +55,19 @@ async def main(args):
try:
logging.debug('Connecting to WebSocket…')
async with open_websocket_url(args.url, ssl_context) as conn:
await handle_connection(conn)
await handle_connection(conn, args.heartbeat)
except OSError as ose:
logging.error('Connection attempt failed: %s', ose)
return False


async def handle_connection(ws):
async def handle_connection(ws, use_heartbeat):
''' Handle the connection. '''
logging.debug('Connected!')
try:
async with trio.open_nursery() as nursery:
if use_heartbeat:
nursery.start_soon(heartbeat, ws, 1, 15)
nursery.start_soon(get_commands, ws)
nursery.start_soon(get_messages, ws)
except ConnectionClosed as cc:
Expand All @@ -72,6 +76,30 @@ async def handle_connection(ws):
print('Closed: {}/{} {}'.format(cc.reason.code, cc.reason.name, reason))


async def heartbeat(ws, timeout, interval):
'''
Send periodic pings on WebSocket ``ws``.
Wait up to ``timeout`` seconds to send a ping and receive a pong. Raises
``TooSlowError`` if the timeout is exceeded. If a pong is received, then
wait ``interval`` seconds before sending the next ping.
This function runs until cancelled.
:param ws: A WebSocket to send heartbeat pings on.
:param float timeout: Timeout in seconds.
:param float interval: Interval between receiving pong and sending next
ping, in seconds.
:raises: ``ConnectionClosed`` if ``ws`` is closed.
:raises: ``TooSlowError`` if the timeout expires.
:returns: This function runs until cancelled.
'''
while True:
with trio.fail_after(timeout):
await ws.ping()
await trio.sleep(interval)


async def get_commands(ws):
''' In a loop: get a command from the user and execute it. '''
while True:
Expand Down
45 changes: 45 additions & 0 deletions tests/test_connection.py
Expand Up @@ -207,6 +207,51 @@ async def test_client_send_and_receive(echo_conn):
assert received_msg == 'This is a test message.'


async def test_client_ping(echo_conn):
async with echo_conn:
await echo_conn.ping(b'A')
with pytest.raises(ConnectionClosed):
await echo_conn.ping(b'B')


async def test_client_ping_two_payloads(echo_conn):
pong_count = 0
async def ping_and_count():
nonlocal pong_count
await echo_conn.ping()
pong_count += 1
async with echo_conn:
async with trio.open_nursery() as nursery:
nursery.start_soon(ping_and_count)
nursery.start_soon(ping_and_count)
assert pong_count == 2


async def test_client_ping_same_payload(echo_conn):
# This test verifies that two tasks can't ping with the same payload at the
# same time. One of them should succeed and the other should get an
# exception.
exc_count = 0
async def ping_and_catch():
nonlocal exc_count
try:
await echo_conn.ping(b'A')
except Exception:
exc_count += 1
async with echo_conn:
async with trio.open_nursery() as nursery:
nursery.start_soon(ping_and_catch)
nursery.start_soon(ping_and_catch)
assert exc_count == 1


async def test_client_pong(echo_conn):
async with echo_conn:
await echo_conn.pong(b'A')
with pytest.raises(ConnectionClosed):
await echo_conn.pong(b'B')


async def test_client_default_close(echo_conn):
async with echo_conn:
assert not echo_conn.is_closed
Expand Down
68 changes: 58 additions & 10 deletions trio_websocket/__init__.py
@@ -1,7 +1,10 @@
from collections import OrderedDict
from functools import partial
import itertools
import logging
import random
import ssl
from functools import partial
import struct

from async_generator import async_generator, yield_, asynccontextmanager
import attr
Expand Down Expand Up @@ -320,6 +323,7 @@ def __init__(self, stream, wsproto, path=None):
self._reader_running = True
self._path = path
self._put_channel, self._get_channel = open_channel(0)
self._pings = OrderedDict()
# Set once the WebSocket open handshake takes place, i.e.
# ConnectionRequested for server or ConnectedEstablished for client.
self._open_handshake = trio.Event()
Expand Down Expand Up @@ -398,13 +402,38 @@ async def get_message(self):
raise ConnectionClosed(self._close_reason) from None
return message

async def ping(self, payload):
async def ping(self, payload=None):
'''
Send WebSocket ping to peer.
Send WebSocket ping to peer and wait for a correspoding pong.
Does not wait for pong reply. (Is this the right behavior? This may
change in the future.) Raises ``ConnectionClosed`` if the connection is
closed.
Each ping is matched to its expected pong by its payload value. An
exception is raised if you call ping with a ``payload`` value equal to
an existing in-flight ping. If the remote endpoint recieves multiple
pings, it is allowed to send a single pong. Therefore, the order of
calls to ``ping()`` is tracked, and a pong will wake up its
corresponding ping _as well as any earlier pings_.
:param payload: The payload to send. If ``None`` then a random value is
created.
:type payload: str, bytes, or None
:raises ConnectionClosed: if connection is closed
'''
if self._close_reason:
raise ConnectionClosed(self._close_reason)
if payload in self._pings:
raise Exception('Payload value {} is already in flight.'.
format(payload))
if payload is None:
payload = struct.pack('!I', random.getrandbits(32))
event = trio.Event()
self._pings[payload] = event
self._wsproto.ping(payload)
await self._write_pending()
await event.wait()

async def pong(self, payload=None):
'''
Send an unsolicted pong.
:param payload: str or bytes payloads
:raises ConnectionClosed: if connection is closed
Expand Down Expand Up @@ -537,18 +566,37 @@ async def _handle_ping_received_event(self, event):
:param event:
'''
logger.debug('conn#%d ping %r', self._id, event.payload)
await self._write_pending()

async def _handle_pong_received_event(self, event):
'''
Handle a PongReceived event.
Currently we don't do anything special for a Pong frame, but this may
change in the future. This handler is here as a placeholder.
When a pong is received, check if we have any ping requests waiting for
this pong response. If the remote endpoint skipped any earlier pings,
then we wake up those skipped pings, too.
This function is async even though it never awaits, because the other
event handlers are async, too, and event dispatch would be more
complicated if some handlers were sync.
:param event:
'''
logger.debug('conn#%d pong %r', self._id, event.payload)
payload = bytes(event.payload)
try:
event = self._pings[payload]
except KeyError:
# We received a pong that doesn't match any in-flight pongs. Nothing
# we can do with it, so ignore it.
return
while self._pings:
key, event = self._pings.popitem(0)
skipped = ' [skipped] ' if payload != key else ' '
logger.debug('conn#%d pong%s%r', self._id, skipped, key)
event.set()
if payload == key:
break

async def _reader_task(self):
''' A background task that reads network data and generates events. '''
Expand Down Expand Up @@ -577,7 +625,7 @@ async def _reader_task(self):
event_type)
await handler(event)
except KeyError:
logger.warning('Received unknown event type: %s',
logger.warning('Received unknown event type: "%s"',
event_type)

# Get network data.
Expand Down

0 comments on commit afd2853

Please sign in to comment.