Skip to content

Commit

Permalink
Merge pull request #84 from michalpokusa/9.0.0-compatibility-and-bett…
Browse files Browse the repository at this point in the history
…er-typing

9.0.0 compatibility and better typing for sockets
  • Loading branch information
FoamyGuy committed Feb 24, 2024
2 parents d8f9a72 + b00f70f commit dc9f83c
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 56 deletions.
63 changes: 61 additions & 2 deletions adafruit_httpserver/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,70 @@
"""

try:
from typing import List, Dict, Union, Any
from typing import List, Tuple, Dict, Union, Any
except ImportError:
pass


class _ISocket: # pylint: disable=missing-function-docstring,no-self-use,unused-argument
"""A class for typing necessary methods for a socket object."""

def accept(self) -> Tuple["_ISocket", Tuple[str, int]]:
...

def bind(self, address: Tuple[str, int]) -> None:
...

def setblocking(self, flag: bool) -> None:
...

def settimeout(self, value: "Union[float, None]") -> None:
...

def setsockopt(self, level: int, optname: int, value: int) -> None:
...

def listen(self, backlog: int) -> None:
...

def send(self, data: bytes) -> int:
...

def recv_into(self, buffer: memoryview, nbytes: int) -> int:
...

def close(self) -> None:
...


class _ISocketPool: # pylint: disable=missing-function-docstring,no-self-use,unused-argument
"""A class to typing necessary methods and properties for a socket pool object."""

AF_INET: int
SO_REUSEADDR: int
SOCK_STREAM: int
SOL_SOCKET: int

def socket( # pylint: disable=redefined-builtin
self,
family: int = ...,
type: int = ...,
proto: int = ...,
) -> _ISocket:
...

def getaddrinfo( # pylint: disable=redefined-builtin,too-many-arguments
self,
host: str,
port: int,
family: int = ...,
type: int = ...,
proto: int = ...,
flags: int = ...,
) -> Tuple[int, int, int, str, Tuple[str, int]]:
...


class _IFieldStorage:
"""Interface with shared methods for QueryParams, FormData and Headers."""

Expand Down Expand Up @@ -62,7 +121,7 @@ def __contains__(self, key: str) -> bool:
return key in self._storage

def __repr__(self) -> str:
return f"{self.__class__.__name__}({repr(self._storage)})"
return f"<{self.__class__.__name__} {repr(self._storage)}>"


def _encode_html_entities(value: Union[str, None]) -> Union[str, None]:
Expand Down
24 changes: 14 additions & 10 deletions adafruit_httpserver/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

try:
from typing import List, Dict, Tuple, Union, Any, TYPE_CHECKING
from socket import socket
from socketpool import SocketPool

if TYPE_CHECKING:
from .server import Server
Expand All @@ -20,7 +18,7 @@
import json

from .headers import Headers
from .interfaces import _IFieldStorage, _IXSSSafeFieldStorage
from .interfaces import _ISocket, _IFieldStorage, _IXSSSafeFieldStorage
from .methods import POST, PUT, PATCH, DELETE


Expand Down Expand Up @@ -127,11 +125,11 @@ def size(self) -> int:

def __repr__(self) -> str:
filename, content_type, size = (
repr(self.filename),
repr(self.content_type),
repr(self.size),
self.filename,
self.content_type,
self.size,
)
return f"{self.__class__.__name__}({filename=}, {content_type=}, {size=})"
return f"<{self.__class__.__name__} {filename=}, {content_type=}, {size=}>"


class Files(_IFieldStorage):
Expand Down Expand Up @@ -260,7 +258,9 @@ def get_list(self, field_name: str, *, safe=True) -> List[Union[str, bytes]]:

def __repr__(self) -> str:
class_name = self.__class__.__name__
return f"{class_name}({repr(self._storage)}, files={repr(self.files._storage)})"
return (
f"<{class_name} {repr(self._storage)}, files={repr(self.files._storage)}>"
)


class Request: # pylint: disable=too-many-instance-attributes
Expand All @@ -274,7 +274,7 @@ class Request: # pylint: disable=too-many-instance-attributes
Server object that received the request.
"""

connection: Union["SocketPool.Socket", "socket.socket"]
connection: _ISocket
"""
Socket object used to send and receive data on the connection.
"""
Expand Down Expand Up @@ -325,7 +325,7 @@ class Request: # pylint: disable=too-many-instance-attributes
def __init__(
self,
server: "Server",
connection: Union["SocketPool.Socket", "socket.socket"],
connection: _ISocket,
client_address: Tuple[str, int],
raw_request: bytes = None,
) -> None:
Expand Down Expand Up @@ -481,6 +481,10 @@ def _parse_request_header(

return method, path, query_params, http_version, headers

def __repr__(self) -> str:
path = self.path + (f"?{self.query_params}" if self.query_params else "")
return f'<{self.__class__.__name__} "{self.method} {path}">'


def _debug_unsupported_form_content_type(content_type: str) -> None:
"""Warns when an unsupported form content type is used."""
Expand Down
11 changes: 7 additions & 4 deletions adafruit_httpserver/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

try:
from typing import Optional, Dict, Union, Tuple, Generator, Any
from socket import socket
from socketpool import SocketPool
except ImportError:
pass

Expand Down Expand Up @@ -47,6 +45,7 @@
PERMANENT_REDIRECT_308,
)
from .headers import Headers
from .interfaces import _ISocket


class Response: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -132,7 +131,7 @@ def _send(self) -> None:

def _send_bytes(
self,
conn: Union["SocketPool.Socket", "socket.socket"],
conn: _ISocket,
buffer: Union[bytes, bytearray, memoryview],
):
bytes_sent: int = 0
Expand Down Expand Up @@ -217,6 +216,10 @@ def __init__( # pylint: disable=too-many-arguments
)
self._filename = filename + "index.html" if filename.endswith("/") else filename
self._root_path = root_path or self._request.server.root_path

if self._root_path is None:
raise ValueError("root_path must be provided in Server or in FileResponse")

self._full_file_path = self._combine_path(self._root_path, self._filename)
self._content_type = content_type or MIMETypes.get_for_filename(self._filename)
self._file_length = self._get_file_length(self._full_file_path)
Expand Down Expand Up @@ -708,7 +711,7 @@ def _read_frame(self):
length -= min(payload_length, length)

if has_mask:
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
payload = bytes(byte ^ mask[idx % 4] for idx, byte in enumerate(payload))

return opcode, payload

Expand Down
8 changes: 4 additions & 4 deletions adafruit_httpserver/route.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ def matches(
return True, dict(zip(self.parameters_names, url_parameters_values))

def __repr__(self) -> str:
path = repr(self.path)
methods = repr(self.methods)
handler = repr(self.handler)
path = self.path
methods = self.methods
handler = self.handler

return f"Route({path=}, {methods=}, {handler=})"
return f"<Route {path=}, {methods=}, {handler=}>"


def as_route(
Expand Down
61 changes: 33 additions & 28 deletions adafruit_httpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@
"""

try:
from typing import Callable, Protocol, Union, List, Tuple, Dict, Iterable
from socket import socket
from socketpool import SocketPool
from typing import Callable, Union, List, Tuple, Dict, Iterable
except ImportError:
pass

Expand All @@ -28,6 +26,7 @@
ServingFilesDisabledError,
)
from .headers import Headers
from .interfaces import _ISocketPool, _ISocket
from .methods import GET, HEAD
from .request import Request
from .response import Response, FileResponse
Expand All @@ -54,7 +53,7 @@ class Server: # pylint: disable=too-many-instance-attributes
"""Root directory to serve files from. ``None`` if serving files is disabled."""

def __init__(
self, socket_source: Protocol, root_path: str = None, *, debug: bool = False
self, socket_source: _ISocketPool, root_path: str = None, *, debug: bool = False
) -> None:
"""Create a server, and get it ready to run.
Expand Down Expand Up @@ -172,7 +171,7 @@ def _verify_can_start(self, host: str, port: int) -> None:
raise RuntimeError(f"Cannot start server on {host}:{port}") from error

def serve_forever(
self, host: str, port: int = 80, *, poll_interval: float = 0.1
self, host: str = "0.0.0.0", port: int = 5000, *, poll_interval: float = 0.1
) -> None:
"""
Wait for HTTP requests at the given host and port. Does not return.
Expand All @@ -195,15 +194,25 @@ def serve_forever(
except Exception: # pylint: disable=broad-except
pass # Ignore exceptions in handler function

def _set_socket_level_to_reuse_address(self) -> None:
"""
Only for CPython, prevents "Address already in use" error when restarting the server.
"""
self._sock.setsockopt(
self._socket_source.SOL_SOCKET, self._socket_source.SO_REUSEADDR, 1
)
@staticmethod
def _create_server_socket(
socket_source: _ISocketPool,
host: str,
port: int,
) -> _ISocket:
sock = socket_source.socket(socket_source.AF_INET, socket_source.SOCK_STREAM)

# TODO: Temporary backwards compatibility, remove after CircuitPython 9.0.0 release
if implementation.version >= (9,) or implementation.name != "circuitpython":
sock.setsockopt(socket_source.SOL_SOCKET, socket_source.SO_REUSEADDR, 1)

sock.bind((host, port))
sock.listen(10)
sock.setblocking(False) # Non-blocking socket

return sock

def start(self, host: str, port: int = 80) -> None:
def start(self, host: str = "0.0.0.0", port: int = 5000) -> None:
"""
Start the HTTP server at the given host and port. Requires calling
``.poll()`` in a while loop to handle incoming requests.
Expand All @@ -216,16 +225,7 @@ def start(self, host: str, port: int = 80) -> None:
self.host, self.port = host, port

self.stopped = False
self._sock = self._socket_source.socket(
self._socket_source.AF_INET, self._socket_source.SOCK_STREAM
)

if implementation.name != "circuitpython":
self._set_socket_level_to_reuse_address()

self._sock.bind((host, port))
self._sock.listen(10)
self._sock.setblocking(False) # Non-blocking socket
self._sock = self._create_server_socket(self._socket_source, host, port)

if self.debug:
_debug_started_server(self)
Expand All @@ -244,9 +244,7 @@ def stop(self) -> None:
if self.debug:
_debug_stopped_server(self)

def _receive_header_bytes(
self, sock: Union["SocketPool.Socket", "socket.socket"]
) -> bytes:
def _receive_header_bytes(self, sock: _ISocket) -> bytes:
"""Receive bytes until a empty line is received."""
received_bytes = bytes()
while b"\r\n\r\n" not in received_bytes:
Expand All @@ -263,7 +261,7 @@ def _receive_header_bytes(

def _receive_body_bytes(
self,
sock: Union["SocketPool.Socket", "socket.socket"],
sock: _ISocket,
received_body_bytes: bytes,
content_length: int,
) -> bytes:
Expand All @@ -282,7 +280,7 @@ def _receive_body_bytes(

def _receive_request(
self,
sock: Union["SocketPool.Socket", "socket.socket"],
sock: _ISocket,
client_address: Tuple[str, int],
) -> Request:
"""Receive bytes from socket until the whole request is received."""
Expand Down Expand Up @@ -530,6 +528,13 @@ def socket_timeout(self, value: int) -> None:
else:
raise ValueError("Server.socket_timeout must be a positive numeric value.")

def __repr__(self) -> str:
host = self.host
port = self.port
root_path = self.root_path

return f"<Server {host=}, {port=}, {root_path=}>"


def _debug_warning_exposed_files(root_path: str):
"""Warns about exposing all files on the device."""
Expand Down
11 changes: 7 additions & 4 deletions adafruit_httpserver/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ def __init__(self, code: int, text: str):
self.code = code
self.text = text

def __repr__(self):
return f'Status({self.code}, "{self.text}")'
def __eq__(self, other: "Status"):
return self.code == other.code and self.text == other.text

def __str__(self):
return f"{self.code} {self.text}"

def __eq__(self, other: "Status"):
return self.code == other.code and self.text == other.text
def __repr__(self):
code = self.code
text = self.text

return f'<Status {code}, "{text}">'


SWITCHING_PROTOCOLS_101 = Status(101, "Switching Protocols")
Expand Down

0 comments on commit dc9f83c

Please sign in to comment.