Skip to content

Commit

Permalink
Added Websocket class and SWITCHING_PROTOCOLS_101
Browse files Browse the repository at this point in the history
  • Loading branch information
michalpokusa committed Jul 13, 2023
1 parent ebb7ca7 commit 1e1ad58
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 2 deletions.
2 changes: 2 additions & 0 deletions adafruit_httpserver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,13 @@
JSONResponse,
Redirect,
SSEResponse,
Websocket,
)
from .route import Route
from .server import Server
from .status import (
Status,
SWITCHING_PROTOCOLS_101,
OK_200,
CREATED_201,
ACCEPTED_202,
Expand Down
285 changes: 283 additions & 2 deletions adafruit_httpserver/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

import os
import json
from errno import EAGAIN, ECONNRESET
from binascii import b2a_base64
import hashlib
from errno import EAGAIN, ECONNRESET, ETIMEDOUT, ENOTCONN

from .exceptions import (
BackslashInPathError,
Expand All @@ -25,7 +27,13 @@
)
from .mime_types import MIMETypes
from .request import Request
from .status import Status, OK_200, TEMPORARY_REDIRECT_307, PERMANENT_REDIRECT_308
from .status import (
Status,
SWITCHING_PROTOCOLS_101,
OK_200,
TEMPORARY_REDIRECT_307,
PERMANENT_REDIRECT_308,
)
from .headers import Headers


Expand Down Expand Up @@ -497,3 +505,276 @@ def close(self):
"""
self._send_bytes(self._request.connection, b"event: close\n")
self._close_connection()


class Websocket(Response): # pylint: disable=too-few-public-methods
"""
Specialized version of `Response` class for creating a websocket connection.
Allows two way communication between the client and the server.
Keep in mind, that in order to send and receive messages, the socket must be kept open.
This means that you have to store the response object somewhere, so you can send events
to it and close it later.
**It is very important to close the connection manually, it will not be done automatically.**
Example::
ws = None
@server.route(path, method)
def route_func(request: Request):
# Store the response object somewhere in global scope
global ws
ws = Websocket(request)
return ws
...
# Receive message from client
message = ws.receive()
# Later, when you want to send an event
ws.send_message("Simple message")
# Close the connection
ws.close()
"""

GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
FIN = 0b10000000 # FIN bit indicating the final fragment

# opcodes
CONT = 0 # Continuation frame, TODO: Currently not supported
TEXT = 1 # Frame contains UTF-8 text
BINARY = 2 # Frame contains binary data
CLOSE = 8 # Frame closes the connection
PING = 9 # Frame is a ping, expecting a pong
PONG = 10 # Frame is a pong, in response to a ping

@staticmethod
def _check_request_initiates_handshake(request: Request):
if any(
[
"websocket" not in request.headers.get("Upgrade", "").lower(),
"upgrade" not in request.headers.get("Connection", "").lower(),
"Sec-WebSocket-Key" not in request.headers,
]
):
raise ValueError("Request does not initiate websocket handshake")

@staticmethod
def _process_sec_websocket_key(request: Request) -> str:
key = request.headers.get("Sec-WebSocket-Key")

if key is None:
raise ValueError("Request does not have Sec-WebSocket-Key header")

response_key = hashlib.new('sha1', key.encode())
response_key.update(Websocket.GUID)

return b2a_base64(response_key.digest()).strip().decode()

def __init__( # pylint: disable=too-many-arguments
self,
request: Request,
headers: Union[Headers, Dict[str, str]] = None,
buffer_size: int = 1024,
) -> None:
"""
:param Request request: Request object
:param Headers headers: Headers to be sent with the response.
:param int buffer_size: Size of the buffer used to send and receive messages.
"""
self._check_request_initiates_handshake(request)

sec_accept_key = self._process_sec_websocket_key(request)

super().__init__(
request=request,
status=SWITCHING_PROTOCOLS_101,
headers=headers,
)
self._headers.setdefault("Upgrade", "websocket")
self._headers.setdefault("Connection", "Upgrade")
self._headers.setdefault("Sec-WebSocket-Accept", sec_accept_key)
self._headers.setdefault("Content-Type", None)
self._buffer_size = buffer_size
self.closed = False

request.connection.setblocking(False)


@staticmethod
def _parse_frame_header(header):
fin = header[0] & Websocket.FIN
opcode = header[0] & 0b00001111
has_mask = header[1] & 0b10000000
length = header[1] & 0b01111111

if length == 0b01111110:
length = -2
elif length == 0b01111111:
length = -8

return fin, opcode, has_mask, length

def _read_frame(self):
buffer = bytearray(self._buffer_size)

header_length = self._request.connection.recv_into(buffer, 2)
header_bytes = buffer[:header_length]

fin, opcode, has_mask, length = self._parse_frame_header(header_bytes)

# TODO: Handle continuation frames, currently not supported
if fin != Websocket.FIN and opcode == Websocket.CONT:
return Websocket.CONT, None

payload = bytes()
if fin == Websocket.FIN and opcode == Websocket.CLOSE:
return Websocket.CLOSE, payload

if length < 0:
length = self._request.connection.recv_into(buffer, -length)
length = int.from_bytes(buffer[:length], 'big')

if has_mask:
mask_length = self._request.connection.recv_into(buffer, 4)
mask = buffer[:mask_length]

while 0 < length:
payload_length = self._request.connection.recv_into(buffer, length)
payload += buffer[:min(payload_length, length)]
length -= min(payload_length, length)

if has_mask:
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))

return opcode, payload

def _handle_frame(self, opcode: int, payload: bytes):
# TODO: Handle continuation frames, currently not supported
if opcode == Websocket.CONT:
return None

if opcode == Websocket.CLOSE:
self.close()
return None

if opcode == Websocket.PONG:
return None
elif opcode == Websocket.PING:
self.send_message(payload, Websocket.PONG)
return payload

try:
payload = payload.decode() if opcode == Websocket.TEXT else payload
except UnicodeError as error:
print("Payload UnicodeError: ", error, payload)
pass

return payload

def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]:
"""
Receive a message from the client.
:param bool fail_silently: If True, no error will be raised if the connection is closed.
"""
if self.closed:
if fail_silently:
return None
raise RuntimeError("Websocket connection is closed, cannot receive messages")

try:
opcode, payload = self._read_frame()
frame_data = self._handle_frame(opcode, payload)

return frame_data
except OSError as error:
if error.errno == EAGAIN: # No messages available
return None
if error.errno == ETIMEDOUT: # Connection timed out
return None
if error.errno == ENOTCONN: # Client disconnected without closing connection
self.close()
return None
raise error

@staticmethod
def _prepare_frame(opcode: int, message: bytes) -> bytearray:
frame = bytearray()

frame.append(Websocket.FIN | opcode) # Setting FIN bit

payload_length = len(message)

# Message under 126 bytes, use 1 byte for length
if payload_length < 126:
frame.append(payload_length)

# Message between 126 and 65535 bytes, use 2 bytes for length
elif payload_length < 65536:
frame.append(126)
frame.extend(payload_length.to_bytes(2, 'big'))

# Message over 65535 bytes, use 8 bytes for length
else:
frame.append(127)
frame.extend(payload_length.to_bytes(8, 'big'))

frame.extend(message)
return frame

def send_message(
self,
message: Union[str, bytes],
opcode: int = None,
fail_silently: bool = False
):
"""
Send a message to the client.
:param str message: Message to be sent.
:param int opcode: Opcode of the message. Defaults to TEXT if message is a string and
BINARY for bytes.
:param bool fail_silently: If True, no error will be raised if the connection is closed.
"""
if self.closed:
if fail_silently:
return None
raise RuntimeError("Websocket connection is closed, cannot send message")

determined_opcode = opcode or (
Websocket.TEXT if isinstance(message, str) else Websocket.BINARY
)

if determined_opcode == Websocket.TEXT:
message = message.encode()

frame = self._prepare_frame(determined_opcode, message)

try:
self._send_bytes(self._request.connection, frame)
except BrokenPipeError as error:
if fail_silently:
return None
raise error

def _send(self) -> None:
self._send_headers()

def close(self):
"""
Close the connection.
**Always call this method when you are done sending events.**
"""
if not self.closed:
self.send_message(b'', Websocket.CLOSE, fail_silently=True)
self._close_connection()
self.closed = True
2 changes: 2 additions & 0 deletions adafruit_httpserver/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def __eq__(self, other: "Status"):
return self.code == other.code and self.text == other.text


SWITCHING_PROTOCOLS_101 = Status(101, "Switching Protocols")

OK_200 = Status(200, "OK")

CREATED_201 = Status(201, "Created")
Expand Down

0 comments on commit 1e1ad58

Please sign in to comment.