Permalink
Browse files

test: headers, requests and responses

  • Loading branch information...
Bogdanp committed May 12, 2018
1 parent 6fa54c4 commit 4455e048f59a4917497298bc03d15ebe65958b19
@@ -1,2 +1,5 @@
__pycache__
.mypy_cache
.coverage
.mypy_cache
.pytest_cache
htmlcov
@@ -25,10 +25,15 @@ There is a tag for each part of the series:
## Type-checking
This repo uses Python 3 type annotations which can be type-checked
using [mypy]. Run `pip install mypy` and then `mypy server.py` to
using [mypy]. Run `pip install mypy` and then `mypy scratch` to
type check the code.
## Testing
Run `pip install pytest` and then `py.test`.
## License
web-app-from-scratch is licensed under Apache 2.0. Please see
@@ -2,3 +2,5 @@ isort
flake8
flake8-quotes
mypy
pytest
pytest-cov
No changes.
@@ -23,7 +23,7 @@ def get(self, name, default=None):
def get_int(self, name):
try:
return int(self.get(name))
except ValueError:
except (TypeError, ValueError):
return None
def __iter__(self):
File renamed without changes.
@@ -2,7 +2,7 @@
import socket
import typing
from headers import Headers
from .headers import Headers
class BodyReader(io.IOBase):
@@ -11,7 +11,7 @@ def __init__(self, sock: socket.socket, *, buff: bytes = b"", bufsize: int = 16_
self._buff = buff
self._bufsize = bufsize
def readable(self) -> bool:
def readable(self) -> bool: # pragma: no cover
return True
def read(self, n: int) -> bytes:
@@ -63,7 +63,7 @@ def from_socket(cls, sock: socket.socket) -> "Request":
break
try:
name, _, value = line.decode("ascii").partition(":")
name, value = line.decode("ascii").split(":", 1)
headers.add(name, value.lstrip())
except ValueError:
raise ValueError(f"Malformed header line {line!r}.")
@@ -92,5 +92,5 @@ def iter_lines(sock: socket.socket, bufsize: int = 16_384) -> typing.Generator[b
return buff
yield line
except IndexError:
except (IndexError, ValueError):
break
@@ -3,7 +3,7 @@
import socket
import typing
from headers import Headers
from .headers import Headers
class Response:
@@ -0,0 +1,168 @@
import logging
import mimetypes
import os
import socket
import typing
from queue import Empty, Queue
from threading import Thread
from typing import Callable, List, Tuple
from .request import Request
from .response import Response
LOGGER = logging.getLogger(__name__)
HandlerT = Callable[[Request], Response]
class HTTPWorker(Thread):
def __init__(self, connection_queue: Queue, handlers: List[Tuple[str, HandlerT]]) -> None:
super().__init__(daemon=True)
self.connection_queue = connection_queue
self.handlers = handlers
self.running = False
def stop(self) -> None:
self.running = False
def run(self) -> None:
self.running = True
while self.running:
try:
client_sock, client_addr = self.connection_queue.get(timeout=1)
except Empty:
continue
try:
self.handle_client(client_sock, client_addr)
except Exception:
LOGGER.exception("Unhandled error in handle_client.")
continue
finally:
self.connection_queue.task_done()
def handle_client(self, client_sock: socket.socket, client_addr: typing.Tuple[str, int]) -> None:
with client_sock:
try:
request = Request.from_socket(client_sock)
except Exception:
LOGGER.warning("Failed to parse request.", exc_info=True)
response = Response(status="400 Bad Request", content="Bad Request")
response.send(client_sock)
return
# Force clients to send their request bodies on every
# request rather than making the handlers deal with this.
if "100-continue" in request.headers.get("expect", ""):
response = Response(status="100 Continue")
response.send(client_sock)
for path_prefix, handler in self.handlers:
if request.path.startswith(path_prefix):
try:
request = request._replace(path=request.path[len(path_prefix):])
response = handler(request)
response.send(client_sock)
except Exception as e:
LOGGER.exception("Unexpected error from handler %r.", handler)
response = Response(status="500 Internal Server Error", content="Internal Error")
response.send(client_sock)
finally:
break
else:
response = Response(status="404 Not Found", content="Not Found")
response.send(client_sock)
class HTTPServer:
def __init__(self, host="127.0.0.1", port=9000, worker_count=16) -> None:
self.handlers: List[Tuple[str, HandlerT]] = []
self.host = host
self.port = port
self.worker_count = worker_count
self.worker_backlog = worker_count * 8
self.connection_queue: Queue = Queue(self.worker_backlog)
def mount(self, path_prefix: str, handler: HandlerT) -> None:
"""Mount a request handler at a particular path. Handler
prefixes are tested in the order that they are added so the
first match "wins".
"""
self.handlers.append((path_prefix, handler))
def serve_forever(self) -> None:
workers = []
for _ in range(self.worker_count):
worker = HTTPWorker(self.connection_queue, self.handlers)
worker.start()
workers.append(worker)
with socket.socket() as server_sock:
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_sock.bind((self.host, self.port))
server_sock.listen(self.worker_backlog)
LOGGER.info("Listening on %s:%d...", self.host, self.port)
while True:
try:
self.connection_queue.put(server_sock.accept())
except KeyboardInterrupt:
break
for worker in workers:
worker.stop()
for worker in workers:
worker.join(timeout=30)
def serve_static(server_root: str) -> HandlerT:
"""Generate a request handler that serves file off of disk
relative to server_root.
"""
def handler(request: Request) -> Response:
path = request.path
if request.path == "/":
path = "/index.html"
abspath = os.path.normpath(os.path.join(server_root, path.lstrip("/")))
if not abspath.startswith(server_root):
return Response(status="404 Not Found", content="Not Found")
try:
content_type, encoding = mimetypes.guess_type(abspath)
if content_type is None:
content_type = "application/octet-stream"
if encoding is not None:
content_type += f"; charset={encoding}"
body_file = open(abspath, "rb")
response = Response(status="200 OK", body=body_file)
response.headers.add("content-type", content_type)
return response
except FileNotFoundError:
return Response(status="404 Not Found", content="Not Found")
return handler
def wrap_auth(handler: HandlerT) -> HandlerT:
def auth_handler(request: Request) -> Response:
authorization = request.headers.get("authorization", "")
if authorization.startswith("Bearer ") and authorization[len("Bearer "):] == "opensesame":
return handler(request)
return Response(status="403 Forbidden", content="Forbidden!")
return auth_handler
def app(request: Request) -> Response:
return Response(content="Hello!")
server = HTTPServer()
server.mount("/static", serve_static("www"))
server.mount("", wrap_auth(app))
server.serve_forever()
@@ -8,3 +8,7 @@ max-line-length = 120
inline-quotes = "
multiline-quotes = """
[tool:pytest]
testpaths = tests
addopts = --cov scratch --cov-report html
No changes.
@@ -0,0 +1 @@
Hello
@@ -0,0 +1,109 @@
from scratch.headers import Headers
def test_can_add_headers():
# Given that I have an empty Headers object
headers = Headers()
# When I add a header
headers.add("x-a-header", "a value")
# Then I should be able to get the value of that header back
assert headers.get("x-a-header") == "a value"
def test_headers_are_case_insensitive():
# Given that I have a Headers object
headers = Headers()
# And I've added a header that's all caps
headers.add("X-A-HEADER", "a value")
# When I get that header using its lower case name
# Then I should get back its value
assert headers.get("x-a-header") == "a value"
def test_getting_a_missing_header_returns_none():
# Given that I have an empty Headers object
headers = Headers()
# When I get that some header
# Then I should get back None
assert headers.get("x-a-header") is None
def test_can_get_headers_with_fallback():
# Given that I have an empty Headers object
headers = Headers()
# When I get that some header with a fallback value
# Then I should get back that fallback value
assert headers.get("x-a-header", "fallback") is "fallback"
def test_can_get_headers_as_ints():
# Given that I have a Headers object
headers = Headers()
# And I've added a header with a stringy int value
headers.add("content-length", "1024")
# When I get that header as an int
# Then I should get back its int value
assert headers.get_int("content-length") == 1024
def test_can_get_headers_as_ints_with_fallback():
# Given that I have an empty Headers object
headers = Headers()
# When I get some header as an int
# Then I should get back None
assert headers.get_int("content-length") is None
def test_getting_a_header_returns_its_last_value():
# Given that I have a Headers object
headers = Headers()
# And I have added a header multiple times
headers.add("x-some-header", "1")
headers.add("x-some-header", "2")
# When I get the value of that header
# Then its last value should be returned
assert headers.get("x-some-header") == "2"
def test_can_get_all_of_a_headers_values():
# Given that I have a Headers object
headers = Headers()
# And I have added a header multiple times
headers.add("x-some-header", "1")
headers.add("x-some-header", "2")
# When I get all of that header's values
# Then I should get back a list containing each value
assert headers.get_all("x-some-header") == ["1", "2"]
def test_headers_is_iterable():
# Given that I have a Headers object
headers = Headers()
# And I've added a number of headers to it
headers.add("content-type", "application/javascript")
headers.add("content-length", "1024")
headers.add("x-some-header", "1")
headers.add("x-some-header", "2")
# When I iterate over it
# Then I should get back a sequence of name, value pairs
assert sorted(list(headers)) == sorted([
("content-type", "application/javascript"),
("content-length", "1024"),
("x-some-header", "1"),
("x-some-header", "2"),
])
Oops, something went wrong.

0 comments on commit 4455e04

Please sign in to comment.