Skip to content

Commit

Permalink
Add API to customize subprotocol selection logic.
Browse files Browse the repository at this point in the history
Also changed the default logic:

(1) to reject client connections that don't offer a subprotocol when the
server is configured with subprotocols.

This is the expected behavior for what I believe to be the default use
case: require one particular subprotocol.

This change of behavior isn't documented because I don't know anyone
embedding the Sans-I/O layer and supporting subprotocol selection at
this time.

(2) to rely only on the order of preference of the server.

Hey, trying to cater to the preferences of clients was nice, but the
behavior was so needlessly complex that documentation apologized...

This keeps things simple and still falls within the documented behavior.
  • Loading branch information
aaugustin committed Dec 18, 2022
1 parent 1363dd5 commit 23a2d3f
Show file tree
Hide file tree
Showing 5 changed files with 136 additions and 57 deletions.
5 changes: 4 additions & 1 deletion docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ Backwards-incompatible changes
:class: caution

Aliases provide compatibility for all previously public APIs according to
the `backwards-compatibility policy`_
the `backwards-compatibility policy`_.

* The ``connection`` module was renamed to ``protocol``.

Expand Down Expand Up @@ -67,6 +67,9 @@ New features

* Made it possible to close a server without closing existing connections.

* Added :attr:`~protocol.ServerProtocol.select_subprotocol` to customize
negotiation of subprotocols in the Sans-I/O layer.

10.4
----

Expand Down
2 changes: 2 additions & 0 deletions docs/reference/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ Sans-I/O

.. automethod:: accept

.. automethod:: select_subprotocol

.. automethod:: reject

.. automethod:: send_response
Expand Down
24 changes: 11 additions & 13 deletions src/websockets/legacy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,31 +520,29 @@ def select_subprotocol(
server_subprotocols: Sequence[Subprotocol],
) -> Optional[Subprotocol]:
"""
Pick a subprotocol among those offered by the client.
Pick a subprotocol among those supported by the client and the server.
If several subprotocols are supported by the client and the server,
the default implementation selects the preferred subprotocol by
giving equal value to the priorities of the client and the server.
If no subprotocol is supported by the client and the server, it
proceeds without a subprotocol.
If several subprotocols are available, select the preferred subprotocol
by giving equal weight to the preferences of the client and the server.
This is unlikely to be the most useful implementation in practice.
Many servers providing a subprotocol will require that the client
uses that subprotocol. Such rules can be implemented in a subclass.
If no subprotocol is available, proceed without a subprotocol.
You may also override this method with the ``select_subprotocol``
argument of :func:`serve` and :class:`WebSocketServerProtocol`.
You may provide a ``select_subprotocol`` argument to :func:`serve` or
:class:`WebSocketServerProtocol` to override this logic. For example,
you could reject the handshake if the client doesn't support a
particular subprotocol, rather than accept the handshake without that
subprotocol.
Args:
client_subprotocols: list of subprotocols offered by the client.
server_subprotocols: list of subprotocols available on the server.
Returns:
Optional[Subprotocol]: Selected subprotocol.
Optional[Subprotocol]: Selected subprotocol, if a common subprotocol
was found.
:obj:`None` to continue without a subprotocol.
"""
if self._select_subprotocol is not None:
return self._select_subprotocol(client_subprotocols, server_subprotocols)
Expand Down
116 changes: 78 additions & 38 deletions src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import email.utils
import http
import warnings
from typing import Any, Generator, List, Optional, Sequence, Tuple, cast
from typing import Any, Callable, Generator, List, Optional, Sequence, Tuple, cast

from .datastructures import Headers, MultipleValuesError
from .exceptions import (
Expand Down Expand Up @@ -58,6 +58,10 @@ class ServerProtocol(Protocol):
should be tried.
subprotocols: list of supported subprotocols, in order of decreasing
preference.
select_subprotocol: callback for selecting a subprotocol among
those supported by the client and the server. It has the same
signature as the :meth:`select_subprotocol` method, including a
:class:`ServerProtocol` instance as first argument.
state: initial state of the WebSocket connection.
max_size: maximum size of incoming messages in bytes;
:obj:`None` to disable the limit.
Expand All @@ -73,6 +77,7 @@ def __init__(
origins: Optional[Sequence[Optional[Origin]]] = None,
extensions: Optional[Sequence[ServerExtensionFactory]] = None,
subprotocols: Optional[Sequence[Subprotocol]] = None,
select_subprotocol: Optional[SelectSubprotocol] = None,
state: State = CONNECTING,
max_size: Optional[int] = 2**20,
logger: Optional[LoggerLike] = None,
Expand All @@ -86,6 +91,14 @@ def __init__(
self.origins = origins
self.available_extensions = extensions
self.available_subprotocols = subprotocols
if select_subprotocol is not None:
# Bind select_subprotocol then shadow self.select_subprotocol.
# Use setattr to work around https://github.com/python/mypy/issues/2427.
setattr(
self,
"select_subprotocol",
select_subprotocol.__get__(self, self.__class__),
)

def accept(self, request: Request) -> Response:
"""
Expand All @@ -96,7 +109,7 @@ def accept(self, request: Request) -> Response:
You must send the handshake response with :meth:`send_response`.
You can modify it before sending it, for example to add HTTP headers.
You may modify it before sending it, for example to add HTTP headers.
Args:
request: WebSocket handshake request event received from the client.
Expand Down Expand Up @@ -175,7 +188,8 @@ def accept(self, request: Request) -> Response:
return Response(101, "Switching Protocols", headers)

def process_request(
self, request: Request
self,
request: Request,
) -> Tuple[str, Optional[str], Optional[str]]:
"""
Check a handshake request and negotiate extensions and subprotocol.
Expand Down Expand Up @@ -273,6 +287,7 @@ def process_origin(self, headers: Headers) -> Optional[Origin]:
Optional[Origin]: origin, if it is acceptable.
Raises:
InvalidHandshake: if the Origin header is invalid.
InvalidOrigin: if the origin isn't acceptable.
"""
Expand Down Expand Up @@ -323,7 +338,7 @@ def process_extensions(
HTTP response header and list of accepted extensions.
Raises:
InvalidHandshake: to abort the handshake with an HTTP 400 error.
InvalidHandshake: if the Sec-WebSocket-Extensions header is invalid.
"""
response_header_value: Optional[str] = None
Expand Down Expand Up @@ -383,60 +398,79 @@ def process_subprotocol(self, headers: Headers) -> Optional[Subprotocol]:
also the value of the ``Sec-WebSocket-Protocol`` response header.
Raises:
InvalidHandshake: to abort the handshake with an HTTP 400 error.
InvalidHandshake: if the Sec-WebSocket-Subprotocol header is invalid.
"""
subprotocol: Optional[Subprotocol] = None

header_values = headers.get_all("Sec-WebSocket-Protocol")

if header_values and self.available_subprotocols:

parsed_header_values: List[Subprotocol] = sum(
[parse_subprotocol(header_value) for header_value in header_values], []
)

subprotocol = self.select_subprotocol(
parsed_header_values, self.available_subprotocols
)
subprotocols: Sequence[Subprotocol] = sum(
[
parse_subprotocol(header_value)
for header_value in headers.get_all("Sec-WebSocket-Protocol")
],
[],
)

return subprotocol
return self.select_subprotocol(subprotocols)

def select_subprotocol(
self,
client_subprotocols: Sequence[Subprotocol],
server_subprotocols: Sequence[Subprotocol],
subprotocols: Sequence[Subprotocol],
) -> Optional[Subprotocol]:
"""
Pick a subprotocol among those offered by the client.
If several subprotocols are supported by the client and the server,
the default implementation selects the preferred subprotocols by
giving equal value to the priorities of the client and the server.
If several subprotocols are supported by both the client and the server,
pick the first one in the list declared the server.
If the server doesn't support any subprotocols, continue without a
subprotocol, regardless of what the client offers.
If no common subprotocol is supported by the client and the server, it
proceeds without a subprotocol.
If the server supports at least one subprotocol and the client doesn't
offer any, abort the handshake with an HTTP 400 error.
This is unlikely to be the most useful implementation in practice, as
many servers providing a subprotocol will require that the client uses
that subprotocol.
You provide a ``select_subprotocol`` argument to :class:`ServerProtocol`
to override this logic. For example, you could accept the connection
even if client doesn't offer a subprotocol, rather than reject it.
Here's how to negotiate the ``chat`` subprotocol if the client supports
it and continue without a subprotocol otherwise::
def select_subprotocol(protocol, subprotocols):
if "chat" in subprotocols:
return "chat"
Args:
client_subprotocols: list of subprotocols offered by the client.
server_subprotocols: list of subprotocols available on the server.
subprotocols: list of subprotocols offered by the client.
Returns:
Optional[Subprotocol]: Subprotocol, if a common subprotocol was
found.
Optional[Subprotocol]: Selected subprotocol, if a common subprotocol
was found.
:obj:`None` to continue without a subprotocol.
Raises:
NegotiationError: custom implementations may raise this exception
to abort the handshake with an HTTP 400 error.
"""
subprotocols = set(client_subprotocols) & set(server_subprotocols)
if not subprotocols:
# Server doesn't offer any subprotocols.
if not self.available_subprotocols: # None or empty list
return None
priority = lambda p: (
client_subprotocols.index(p) + server_subprotocols.index(p)

# Server offers at least one subprotocol but client doesn't offer any.
if not subprotocols:
raise NegotiationError("missing subprotocol")

# Server and client both offer subprotocols. Look for a shared one.
proposed_subprotocols = set(subprotocols)
for subprotocol in self.available_subprotocols:
if subprotocol in proposed_subprotocols:
return subprotocol

# No common subprotocol was found.
raise NegotiationError(
"invalid subprotocol; expected one of "
+ ", ".join(self.available_subprotocols)
)
return sorted(subprotocols, key=priority)[0]

def reject(
self,
Expand Down Expand Up @@ -519,6 +553,12 @@ def parse(self) -> Generator[None, None, None]:
yield from super().parse()


SelectSubprotocol = Callable[
[ServerProtocol, Sequence[Subprotocol]],
Optional[Subprotocol],
]


class ServerConnection(ServerProtocol):
def __init__(self, *args: Any, **kwargs: Any) -> None:
warnings.warn(
Expand Down
46 changes: 41 additions & 5 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import unittest.mock

from websockets.datastructures import Headers
from websockets.exceptions import InvalidHeader, InvalidOrigin, InvalidUpgrade
from websockets.exceptions import (
InvalidHeader,
InvalidOrigin,
InvalidUpgrade,
NegotiationError,
)
from websockets.frames import OP_TEXT, Frame
from websockets.http11 import Request, Response
from websockets.protocol import CONNECTING, OPEN
Expand Down Expand Up @@ -544,9 +549,12 @@ def test_no_subprotocol(self):
request = self.make_request()
response = server.accept(request)

self.assertEqual(response.status_code, 101)
self.assertNotIn("Sec-WebSocket-Protocol", response.headers)
self.assertIsNone(server.subprotocol)
self.assertEqual(response.status_code, 400)
with self.assertRaisesRegex(
NegotiationError,
r"missing subprotocol",
):
raise server.handshake_exc

def test_subprotocol(self):
server = ServerProtocol(subprotocols=["chat"])
Expand All @@ -571,8 +579,8 @@ def test_unexpected_subprotocol(self):
def test_multiple_subprotocols(self):
server = ServerProtocol(subprotocols=["superchat", "chat"])
request = self.make_request()
request.headers["Sec-WebSocket-Protocol"] = "superchat"
request.headers["Sec-WebSocket-Protocol"] = "chat"
request.headers["Sec-WebSocket-Protocol"] = "superchat"
response = server.accept(request)

self.assertEqual(response.status_code, 101)
Expand All @@ -595,6 +603,34 @@ def test_unsupported_subprotocol(self):
request.headers["Sec-WebSocket-Protocol"] = "otherchat"
response = server.accept(request)

self.assertEqual(response.status_code, 400)
with self.assertRaisesRegex(
NegotiationError,
r"invalid subprotocol; expected one of superchat, chat",
):
raise server.handshake_exc

@staticmethod
def optional_chat(protocol, subprotocols):
if "chat" in subprotocols:
return "chat"

def test_select_subprotocol(self):
server = ServerProtocol(select_subprotocol=self.optional_chat)
request = self.make_request()
request.headers["Sec-WebSocket-Protocol"] = "chat"
response = server.accept(request)

self.assertEqual(response.status_code, 101)
self.assertEqual(response.headers["Sec-WebSocket-Protocol"], "chat")
self.assertEqual(server.subprotocol, "chat")

def test_select_no_subprotocol(self):
server = ServerProtocol(select_subprotocol=self.optional_chat)
request = self.make_request()
request.headers["Sec-WebSocket-Protocol"] = "otherchat"
response = server.accept(request)

self.assertEqual(response.status_code, 101)
self.assertNotIn("Sec-WebSocket-Protocol", response.headers)
self.assertIsNone(server.subprotocol)
Expand Down

0 comments on commit 23a2d3f

Please sign in to comment.