Skip to content

Commit

Permalink
Make it possible to fragment outgoing messages.
Browse files Browse the repository at this point in the history
Fix #258.
  • Loading branch information
aaugustin committed Sep 23, 2018
1 parent 4e7a82e commit 84ce48c
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 26 deletions.
2 changes: 2 additions & 0 deletions docs/changelog.rst
Expand Up @@ -31,6 +31,8 @@ Changelog
* websockets sends Ping frames at regular intervals and closes the connection
if it doesn't receive a matching Pong frame. See :class:`~protocol.WebSocketCommonProtocol` for details.

* Added support for sending fragmented messages.

* Added the :meth:`~protocol.WebSocketCommonProtocol.wait_closed` method to
protocols.

Expand Down
73 changes: 61 additions & 12 deletions src/websockets/protocol.py
Expand Up @@ -11,6 +11,7 @@
import binascii
import codecs
import collections
import collections.abc
import enum
import logging
import random
Expand Down Expand Up @@ -428,20 +429,66 @@ def send(self, data):
This coroutine sends a message.
It sends :class:`str` as a text frame and :class:`bytes` as a binary
frame. It raises a :exc:`TypeError` for other inputs.
frame.
It also accepts an iterable of :class:`str` or :class:`bytes`. Each
item is treated as a message fragment and sent in its own frame. All
items must be of the same type, or else :meth:`send` will raise a
:exc:`TypeError` and the connection will be closed.
It raises a :exc:`TypeError` for other inputs.
"""
yield from self.ensure_open()

# Unfragmented message (first because str and bytes are iterable).

if isinstance(data, str):
opcode = 1
data = data.encode('utf-8')
yield from self.write_frame(True, OP_TEXT, data.encode('utf-8'))

elif isinstance(data, bytes):
opcode = 2
else:
raise TypeError("data must be bytes or str")
yield from self.write_frame(True, OP_BINARY, data)

yield from self.write_frame(opcode, data)
# Fragmented message -- regular iterator.

elif isinstance(data, collections.abc.Iterable):
iter_data = iter(data)

# First fragment.
try:
data = next(iter_data)
except StopIteration:
return
data_type = type(data)
if isinstance(data, str):
yield from self.write_frame(False, OP_TEXT, data.encode('utf-8'))
encode_data = True
elif isinstance(data, bytes):
yield from self.write_frame(False, OP_BINARY, data)
encode_data = False
else:
raise TypeError("data must be an iterable of bytes or str")

# Other fragments.
for data in iter_data:
if type(data) != data_type:
# We're half-way through a fragmented message and we can't
# complete it. This makes the connection unusable.
self.fail_connection(1011)
raise TypeError("data contains inconsistent types")
if encode_data:
data = data.encode('utf-8')
yield from self.write_frame(False, OP_CONT, data)

# Final fragment.
yield from self.write_frame(True, OP_CONT, type(data)())

# Fragmented message -- asynchronous iterator

# To be implemented after dropping support for Python 3.4.

else:
raise TypeError("data must be bytes, str, or iterable")

@asyncio.coroutine
def close(self, code=1000, reason=''):
Expand Down Expand Up @@ -529,7 +576,7 @@ def ping(self, data=None):

self.pings[data] = asyncio.Future(loop=self.loop)

yield from self.write_frame(OP_PING, data)
yield from self.write_frame(True, OP_PING, data)

return asyncio.shield(self.pings[data])

Expand All @@ -549,7 +596,7 @@ def pong(self, data=b''):

data = encode_data(data)

yield from self.write_frame(OP_PONG, data)
yield from self.write_frame(True, OP_PONG, data)

# Private methods - no guarantees.

Expand Down Expand Up @@ -803,14 +850,14 @@ def read_frame(self, max_size):
return frame

@asyncio.coroutine
def write_frame(self, opcode, data=b'', _expected_state=State.OPEN):
def write_frame(self, fin, opcode, data, *, _expected_state=State.OPEN):
# Defensive assertion for protocol compliance.
if self.state is not _expected_state: # pragma: no cover
raise InvalidState(
"Cannot write to a WebSocket " "in the {} state".format(self.state.name)
)

frame = Frame(True, opcode, data)
frame = Frame(fin, opcode, data)
logger.debug("%s > %s", self.side, frame)
frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions)

Expand Down Expand Up @@ -870,7 +917,9 @@ def write_close_frame(self, data=b''):
logger.debug("%s - state = CLOSING", self.side)

# 7.1.2. Start the WebSocket Closing Handshake
yield from self.write_frame(OP_CLOSE, data, State.CLOSING)
yield from self.write_frame(
True, OP_CLOSE, data, _expected_state=State.CLOSING
)

@asyncio.coroutine
def keepalive_ping(self):
Expand Down
68 changes: 54 additions & 14 deletions tests/test_protocol.py
Expand Up @@ -255,12 +255,9 @@ def process_invalid_frames(self):
self.receive_eof()
self.loop.run_until_complete(self.protocol.close_connection_task)

def last_sent_frame(self):
def sent_frames(self):
"""
Read the last frame sent to the transport.
This method assumes that at most one frame was sent. It raises an
AssertionError otherwise.
Read all frames sent to the transport.
"""
stream = asyncio.StreamReader(loop=self.loop)
Expand All @@ -270,18 +267,30 @@ def last_sent_frame(self):
self.transport.write.call_args_list = []
stream.feed_eof()

if stream.at_eof():
frame = None
else:
frame = self.loop.run_until_complete(
Frame.read(stream.readexactly, mask=self.protocol.is_client)
frames = []
while not stream.at_eof():
frames.append(
self.loop.run_until_complete(
Frame.read(stream.readexactly, mask=self.protocol.is_client)
)
)
return frames

def last_sent_frame(self):
"""
Read the last frame sent to the transport.
This method assumes that at most one frame was sent. It raises an
AssertionError otherwise.
if not stream.at_eof(): # pragma: no cover
data = self.loop.run_until_complete(stream.read())
raise AssertionError("Trailing data found: {!r}".format(data))
"""
frames = self.sent_frames()
if frames:
assert len(frames) == 1
return frames[0]

return frame
def assertFramesSent(self, *frames):
self.assertEqual(self.sent_frames(), [Frame(*args) for args in frames])

def assertOneFrameSent(self, *args):
self.assertEqual(self.last_sent_frame(), Frame(*args))
Expand Down Expand Up @@ -467,6 +476,37 @@ def test_send_type_error(self):
self.loop.run_until_complete(self.protocol.send(42))
self.assertNoFrameSent()

def test_send_iterable_text(self):
self.loop.run_until_complete(self.protocol.send(['ca', 'fé']))
self.assertFramesSent(
(False, OP_TEXT, 'ca'.encode('utf-8')),
(False, OP_CONT, 'fé'.encode('utf-8')),
(True, OP_CONT, ''.encode('utf-8')),
)

def test_send_iterable_binary(self):
self.loop.run_until_complete(self.protocol.send([b'te', b'a']))
self.assertFramesSent(
(False, OP_BINARY, b'te'), (False, OP_CONT, b'a'), (True, OP_CONT, b'')
)

def test_send_empty_iterable(self):
self.loop.run_until_complete(self.protocol.send([]))
self.assertNoFrameSent()

def test_send_iterable_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.send([42]))
self.assertNoFrameSent()

def test_send_iterable_mixed_type_error(self):
with self.assertRaises(TypeError):
self.loop.run_until_complete(self.protocol.send(['café', b'tea']))
self.assertFramesSent(
(False, OP_TEXT, 'café'.encode('utf-8')),
(True, OP_CLOSE, serialize_close(1011, '')),
)

def test_send_on_closing_connection_local(self):
close_task = self.half_close_connection_local()

Expand Down

0 comments on commit 84ce48c

Please sign in to comment.