Skip to content

Commit

Permalink
Merge pull request #79 from HyperionGray/delay_connection_closed
Browse files Browse the repository at this point in the history
Delay connection closed (#69)
  • Loading branch information
mehaase committed Nov 10, 2018
2 parents d6a238d + 1ca82e6 commit 64822ca
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 28 deletions.
68 changes: 66 additions & 2 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,70 @@ async def handler(request):
with trio.fail_after(2):
async with open_websocket(
HOST, server.port, '/', use_ssl=False) as connection:
with pytest.raises(ConnectionClosed) as e:
with pytest.raises(ConnectionClosed) as exc_info:
await connection.get_message()
assert e.reason.name == 'NORMAL_CLOSURE'
exc = exc_info.value
assert exc.reason.name == 'NORMAL_CLOSURE'


@pytest.mark.skip(reason='Hangs because channel size is hard coded to 0')
async def test_read_messages_after_remote_close(nursery):
'''
When the remote endpoint closes, the local endpoint can still read all
of the messages sent prior to closing. Any attempt to read beyond that will
raise ConnectionClosed.
'''
server_closed = trio.Event()

async def handler(request):
server = await request.accept()
async with server:
await server.send_message('1')
await server.send_message('2')
server_closed.set()

server = await nursery.start(
partial(serve_websocket, handler, HOST, 0, ssl_context=None))

async with open_websocket(HOST, server.port, '/', use_ssl=False) as client:
await server_closed.wait()
assert await client.get_message() == '1'
assert await client.get_message() == '2'
with pytest.raises(ConnectionClosed):
await client.get_message()


async def test_no_messages_after_local_close(nursery):
'''
If the local endpoint initiates closing, then pending messages are discarded
and any attempt to read a message will raise ConnectionClosed.
'''
client_closed = trio.Event()

async def handler(request):
# The server sends some messages and then closes.
server = await request.accept()
async with server:
await server.send_message('1')
await server.send_message('2')
await client_closed.wait()

server = await nursery.start(
partial(serve_websocket, handler, HOST, 0, ssl_context=None))

async with open_websocket(HOST, server.port, '/', use_ssl=False) as client:
pass
with pytest.raises(ConnectionClosed):
await client.get_message()
client_closed.set()


async def test_client_cm_exit_with_pending_messages(echo_server, autojump_clock):
with trio.fail_after(1):
async with open_websocket(HOST, echo_server.port, RESOURCE,
use_ssl=False) as ws:
await ws.send_message('hello')
# allow time for the server to respond
await trio.sleep(.1)
# bug: context manager exit is blocked on unconsumed message
#await ws.get_message()
54 changes: 28 additions & 26 deletions trio_websocket/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import trio.ssl
import wsproto.connection as wsconnection
import wsproto.frame_protocol as wsframeproto
from wsproto.events import BytesReceived
from yarl import URL

from .version import __version__
Expand Down Expand Up @@ -439,8 +440,7 @@ def __init__(self, stream, wsproto, *, path=None):
self._stream = stream
self._stream_lock = trio.StrictFIFOLock()
self._wsproto = wsproto
self._bytes_message = b''
self._str_message = ''
self._message_parts = [] # type: List[bytes|str]
self._reader_running = True
self._path = path
self._subprotocol = None
Expand Down Expand Up @@ -514,6 +514,7 @@ async def aclose(self, code=1000, reason=None):
return
self._wsproto.close(code=code, reason=reason)
try:
await self._recv_channel.aclose()
await self._write_pending()
await self._close_handshake.wait()
finally:
Expand All @@ -526,17 +527,21 @@ async def get_message(self):
Receive the next WebSocket message.
If no message is available immediately, then this function blocks until
a message is ready. When the connection is closed, this message
a message is ready.
If the remote endpoint closes the connection, then the caller can still
get messages sent prior to closing. Once all pending messages have been
retrieved, additional calls to this method will raise
``ConnectionClosed``. If the local endpoint closes the connection, then
pending messages are discarded and calls to this method will immediately
raise ``ConnectionClosed``.
:rtype: str or bytes
:raises ConnectionClosed: if connection is closed before a message
arrives.
:raises ConnectionClosed: if the connection is closed.
'''
if self._close_reason:
raise ConnectionClosed(self._close_reason)
try:
message = await self._recv_channel.receive()
except trio.EndOfChannel:
except (trio.ClosedResourceError, trio.EndOfChannel):
raise ConnectionClosed(self._close_reason) from None
return message

Expand Down Expand Up @@ -714,27 +719,24 @@ async def _handle_connection_failed_event(self, event):
self._open_handshake.set()
self._close_handshake.set()

async def _handle_bytes_received_event(self, event):
'''
Handle a BytesReceived event.
:param event:
async def _handle_data_received_event(self, event):
'''
self._bytes_message += event.data
if event.message_finished:
await self._send_channel.send(self._bytes_message)
self._bytes_message = b''

async def _handle_text_received_event(self, event):
'''
Handle a TextReceived event.
Handle a BytesReceived or TextReceived event.
:param event:
'''
self._str_message += event.data
self._message_parts.append(event.data)
if event.message_finished:
await self._send_channel.send(self._str_message)
self._str_message = ''
msg = (b'' if isinstance(event, BytesReceived) else '') \
.join(self._message_parts)
self._message_parts = []
try:
await self._send_channel.send(msg)
except trio.BrokenResourceError:
# The receive channel is closed, probably because somebody
# called ``aclose()``. We don't want to abort the reader task,
# and there's no useful cleanup that we can do here.
pass

async def _handle_ping_received_event(self, event):
'''
Expand Down Expand Up @@ -784,8 +786,8 @@ async def _reader_task(self):
'ConnectionFailed': self._handle_connection_failed_event,
'ConnectionEstablished': self._handle_connection_established_event,
'ConnectionClosed': self._handle_connection_closed_event,
'BytesReceived': self._handle_bytes_received_event,
'TextReceived': self._handle_text_received_event,
'BytesReceived': self._handle_data_received_event,
'TextReceived': self._handle_data_received_event,
'PingReceived': self._handle_ping_received_event,
'PongReceived': self._handle_pong_received_event,
}
Expand Down

0 comments on commit 64822ca

Please sign in to comment.