Skip to content

Commit

Permalink
Improve error handling in broadcast().
Browse files Browse the repository at this point in the history
Fix #1319.
  • Loading branch information
aaugustin committed Apr 1, 2023
1 parent ce06dd6 commit 2fcc483
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 26 deletions.
2 changes: 2 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ New features
Improvements
............

* Improved error handling in :func:`~websockets.broadcast`.

* Set ``server_hostname`` automatically on TLS connections when providing a
``sock`` argument to :func:`~sync.client.connect`.

Expand Down
5 changes: 5 additions & 0 deletions docs/topics/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,11 @@ Here's what websockets logs at each level.
* Exceptions raised by connection handler coroutines in servers
* Exceptions resulting from bugs in websockets

``WARNING``
...........

* Failures in :func:`~websockets.broadcast`

``INFO``
........

Expand Down
74 changes: 54 additions & 20 deletions src/websockets/legacy/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import random
import ssl
import struct
import sys
import time
import uuid
import warnings
Expand Down Expand Up @@ -1573,54 +1574,87 @@ def eof_received(self) -> None:
self.reader.feed_eof()


def broadcast(websockets: Iterable[WebSocketCommonProtocol], message: Data) -> None:
def broadcast(
websockets: Iterable[WebSocketCommonProtocol],
message: Data,
raise_exceptions: bool = False,
) -> None:
"""
Broadcast a message to several WebSocket connections.
A string (:class:`str`) is sent as a Text_ frame. A bytestring or
bytes-like object (:class:`bytes`, :class:`bytearray`, or
:class:`memoryview`) is sent as a Binary_ frame.
A string (:class:`str`) is sent as a Text_ frame. A bytestring or bytes-like
object (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) is sent
as a Binary_ frame.
.. _Text: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
.. _Binary: https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
:func:`broadcast` pushes the message synchronously to all connections even
if their write buffers are overflowing. There's no backpressure.
:func:`broadcast` skips silently connections that aren't open in order to
avoid errors on connections where the closing handshake is in progress.
If you broadcast messages faster than a connection can handle them,
messages will pile up in its write buffer until the connection times out.
Keep low values for ``ping_interval`` and ``ping_timeout`` to prevent
excessive memory usage by slow connections when you use :func:`broadcast`.
If you broadcast messages faster than a connection can handle them, messages
will pile up in its write buffer until the connection times out. Keep
``ping_interval`` and ``ping_timeout`` low to prevent excessive memory usage
from slow connections.
Unlike :meth:`~websockets.server.WebSocketServerProtocol.send`,
:func:`broadcast` doesn't support sending fragmented messages. Indeed,
fragmentation is useful for sending large messages without buffering
them in memory, while :func:`broadcast` buffers one copy per connection
as fast as possible.
fragmentation is useful for sending large messages without buffering them in
memory, while :func:`broadcast` buffers one copy per connection as fast as
possible.
:func:`broadcast` skips connections that aren't open in order to avoid
errors on connections where the closing handshake is in progress.
:func:`broadcast` ignores failures to write the message on some connections.
It continues writing to other connections. On Python 3.11 and above, you
may set ``raise_exceptions`` to :obj:`True` to record failures and raise all
exceptions in a :pep:`654` :exc:`ExceptionGroup`.
Args:
websockets (Iterable[WebSocketCommonProtocol]): WebSocket connections
to which the message will be sent.
message (Data): Message to send.
websockets: WebSocket connections to which the message will be sent.
message: Message to send.
raise_exceptions: Whether to raise an exception in case of failures.
Raises:
RuntimeError: If a connection is busy sending a fragmented message.
TypeError: If ``message`` doesn't have a supported type.
"""
if not isinstance(message, (str, bytes, bytearray, memoryview)):
raise TypeError("data must be str or bytes-like")

if raise_exceptions:
if sys.version_info[:2] < (3, 11): # pragma: no cover
raise ValueError("raise_exceptions requires at least Python 3.11")
exceptions = []

opcode, data = prepare_data(message)

for websocket in websockets:
if websocket.state is not State.OPEN:
continue

if websocket._fragmented_message_waiter is not None:
raise RuntimeError("busy sending a fragmented message")
if raise_exceptions:
exception = RuntimeError("sending a fragmented message")
exceptions.append(exception)
else:
websocket.logger.warning(
"skipped broadcast: sending a fragmented message",
)

try:
websocket.write_frame_sync(True, opcode, data)
except Exception as write_exception:
if raise_exceptions:
exception = RuntimeError("failed to write message")
exception.__cause__ = write_exception
exceptions.append(exception)
else:
websocket.logger.warning(
"skipped broadcast: failed to write message",
exc_info=True,
)

websocket.write_frame_sync(True, opcode, data)
if raise_exceptions:
raise ExceptionGroup("skipped broadcast", exceptions)
4 changes: 2 additions & 2 deletions tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,11 +1553,11 @@ async def run_client():
await ws.recv()
else:
# Exit block with an exception.
raise Exception("BOOM!")
raise Exception("BOOM")
pass # work around bug in coverage

with self.assertLogs("websockets", logging.INFO) as logs:
with self.assertRaisesRegex(Exception, "BOOM!"):
with self.assertRaisesRegex(Exception, "BOOM"):
self.loop.run_until_complete(run_client())

# Iteration 1
Expand Down
60 changes: 56 additions & 4 deletions tests/legacy/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import contextlib
import logging
import sys
import unittest
import unittest.mock
import warnings
Expand Down Expand Up @@ -1468,26 +1470,76 @@ def test_broadcast_two_clients(self):
def test_broadcast_skips_closed_connection(self):
self.close_connection()

broadcast([self.protocol], "café")
with self.assertNoLogs():
broadcast([self.protocol], "café")
self.assertNoFrameSent()

def test_broadcast_skips_closing_connection(self):
close_task = self.half_close_connection_local()

broadcast([self.protocol], "café")
with self.assertNoLogs():
broadcast([self.protocol], "café")
self.assertNoFrameSent()

self.loop.run_until_complete(close_task) # cleanup

def test_broadcast_within_fragmented_text(self):
def test_broadcast_skips_connection_sending_fragmented_text(self):
self.make_drain_slow()
self.loop.create_task(self.protocol.send(["ca", "fé"]))
self.run_loop_once()
self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8"))

with self.assertRaises(RuntimeError):
with self.assertLogs("websockets", logging.WARNING) as logs:
broadcast([self.protocol], "café")

self.assertEqual(
[record.getMessage() for record in logs.records][:2],
["skipped broadcast: sending a fragmented message"],
)

@unittest.skipIf(
sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+"
)
def test_broadcast_reports_connection_sending_fragmented_text(self):
self.make_drain_slow()
self.loop.create_task(self.protocol.send(["ca", "fé"]))
self.run_loop_once()
self.assertOneFrameSent(False, OP_TEXT, "ca".encode("utf-8"))

with self.assertRaises(ExceptionGroup) as raised:
broadcast([self.protocol], "café", raise_exceptions=True)

self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)")
self.assertEqual(
str(raised.exception.exceptions[0]), "sending a fragmented message"
)

def test_broadcast_skips_connection_failing_to_send(self):
# Configure mock to raise an exception when writing to the network.
self.protocol.transport.write.side_effect = RuntimeError

with self.assertLogs("websockets", logging.WARNING) as logs:
broadcast([self.protocol], "café")

self.assertEqual(
[record.getMessage() for record in logs.records][:2],
["skipped broadcast: failed to write message"],
)

@unittest.skipIf(
sys.version_info[:2] < (3, 11), "raise_exceptions requires Python 3.11+"
)
def test_broadcast_reports_connection_failing_to_send(self):
# Configure mock to raise an exception when writing to the network.
self.protocol.transport.write.side_effect = RuntimeError("BOOM")

with self.assertRaises(ExceptionGroup) as raised:
broadcast([self.protocol], "café", raise_exceptions=True)

self.assertEqual(str(raised.exception), "skipped broadcast (1 sub-exception)")
self.assertEqual(str(raised.exception.exceptions[0]), "failed to write message")
self.assertEqual(str(raised.exception.exceptions[0].__cause__), "BOOM")


class ServerTests(CommonTests, AsyncioTestCase):
def setUp(self):
Expand Down

0 comments on commit 2fcc483

Please sign in to comment.