Skip to content

Commit

Permalink
Added support for middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed May 3, 2022
1 parent 52a999e commit 2ba60ce
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 7 deletions.
44 changes: 40 additions & 4 deletions src/asphalt/web/asgi3.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from asyncio import create_task, sleep
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from inspect import isfunction
from typing import Any, Generic, TypeVar
Expand Down Expand Up @@ -58,6 +59,8 @@ class ASGIComponent(ContainerComponent, Generic[T_Application]):
:type app: asgiref.typing.ASGI3Application | None
:param host: the IP address to bind to
:param port: the port to bind to
:param middlewares: list of callables or dicts to be added as middleware using
:meth:`add_middleware`
"""

def __init__(
Expand All @@ -67,19 +70,52 @@ def __init__(
app: T_Application,
host: str = "127.0.0.1",
port: int = 8000,
middlewares: Sequence[Callable[..., ASGI3Application] | dict[str, Any]] = (),
) -> None:
super().__init__(components)
self.app: T_Application = resolve_reference(app)
self.original_app = app
self.host = host
self.port = port

self.add_middleware(self.wrap_in_middleware)
for middleware in middlewares:
self.add_middleware(middleware)

def wrap_in_middleware(self, app: T_Application) -> ASGI3Application:
return AsphaltMiddleware(app)

def add_middleware(
self, middleware: Callable[..., ASGI3Application] | dict[str, Any]
) -> None:
"""
Add middleware to the application.
:param middleware: either a callable that takes the application object and
returns an ASGI 3.0 application, or a dictionary containing a reference to
such a callable. This dictionary must contain the key ``type`` which is a
non-async callable (or a module:varname reference to one) and which will be
called with the application object as the first positional argument and the
rest of the keys in the dict as keyword arguments.
"""
if isinstance(middleware, dict):
type_ = resolve_reference(middleware.pop("type", None))
if not callable(type_):
raise TypeError(f"Middleware ({type_}) is not callable")

self.app = type_(self.app, **middleware)
elif callable(middleware):
self.app = middleware(self.app)
else:
raise TypeError(
f"middleware must be either a callable or a dict, not {middleware!r}"
)

@context_teardown
async def start(self, ctx: Context):
config = Config(
app=self.wrap_in_middleware(self.app),
app=self.app,
host=self.host,
port=self.port,
use_colors=False,
Expand All @@ -88,10 +124,10 @@ async def start(self, ctx: Context):
)

types = [ASGI3Application]
if not isfunction(self.app):
types.append(type(self.app))
if not isfunction(self.original_app):
types.append(type(self.original_app))

ctx.add_resource(self.app, types=types)
ctx.add_resource(self.original_app, types=types)
await super().start(ctx)

server = uvicorn.Server(config)
Expand Down
2 changes: 2 additions & 0 deletions src/asphalt/web/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class DjangoComponent(ASGIComponent[ASGIHandler]):
:param django.core.handlers.asgi.ASGIHandler app: the Django ASGI handler object
:param host: the IP address to bind to
:param port: the port to bind to
:param middlewares: list of callables or dicts to be added as middleware using
:meth:`add_middleware`
"""

def wrap_in_middleware(self, app: ASGIHandler) -> ASGI3Application:
Expand Down
10 changes: 9 additions & 1 deletion src/asphalt/web/fastapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from inspect import Parameter, Signature, signature
from typing import Any
Expand Down Expand Up @@ -55,6 +56,8 @@ class FastAPIComponent(ASGIComponent[FastAPI]):
(default: the value of
`__debug__ <https://docs.python.org/3/library/constants.html#debug__>`_;
ignored if an application object is explicitly passed in)
:param middlewares: list of callables or dicts to be added as middleware using
:meth:`add_middleware`
"""

def __init__(
Expand All @@ -65,10 +68,15 @@ def __init__(
host: str = "127.0.0.1",
port: int = 8000,
debug: bool | None = None,
middlewares: Sequence[Callable[..., ASGI3Application] | dict[str, Any]] = (),
) -> None:
debug = debug if isinstance(debug, bool) else __debug__
super().__init__(
components, app=app or FastAPI(debug=debug), host=host, port=port
components,
app=app or FastAPI(debug=debug),
host=host,
port=port,
middlewares=middlewares,
)

def wrap_in_middleware(self, app: FastAPI) -> ASGI3Application:
Expand Down
10 changes: 9 additions & 1 deletion src/asphalt/web/starlette.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any

from asgiref.typing import ASGI3Application, HTTPScope, WebSocketScope
Expand Down Expand Up @@ -41,6 +42,8 @@ class StarletteComponent(ASGIComponent[Starlette]):
(default: the value of
`__debug__ <https://docs.python.org/3/library/constants.html#debug__>`_;
ignored if an application object is explicitly passed in)
:param middlewares: list of callables or dicts to be added as middleware using
:meth:`add_middleware`
"""

def __init__(
Expand All @@ -51,10 +54,15 @@ def __init__(
host: str = "127.0.0.1",
port: int = 8000,
debug: bool | None = None,
middlewares: Sequence[Callable[..., ASGI3Application] | dict[str, Any]] = (),
) -> None:
debug = debug if isinstance(debug, bool) else __debug__
super().__init__(
components, app=app or Starlette(debug=debug), host=host, port=port
components,
app=app or Starlette(debug=debug),
host=host,
port=port,
middlewares=middlewares,
)

def wrap_in_middleware(self, app: Starlette) -> ASGI3Application:
Expand Down
62 changes: 61 additions & 1 deletion tests/test_asgi3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import json
from typing import cast
from collections.abc import Callable, Sequence
from typing import Any, cast
from urllib.parse import parse_qs

import pytest
Expand All @@ -8,7 +11,9 @@
ASGI3Application,
ASGIReceiveCallable,
ASGISendCallable,
ASGISendEvent,
HTTPScope,
Scope,
WebSocketScope,
)
from asphalt.core import Context, current_context, inject, resource
Expand Down Expand Up @@ -74,6 +79,24 @@ async def application(
)


class TextReplacerMiddleware:
def __init__(self, app: ASGI3Application, text: str, replacement: str):
self.app = app
self.text = text.encode()
self.replacement = replacement.encode()

async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None:
async def wrapped_send(event: ASGISendEvent) -> None:
if event["type"] == "http.response.body":
event["body"] = event["body"].replace(self.text, self.replacement)

await send(event)

await self.app(scope, receive, wrapped_send)


@pytest.mark.asyncio
async def test_asgi_http(unused_tcp_port: int):
async with Context() as ctx, AsyncClient() as http:
Expand Down Expand Up @@ -115,3 +138,40 @@ async def test_asgi_ws(unused_tcp_port: int):
"my resource": "foo",
"another resource": "bar",
}


@pytest.mark.parametrize("method", ["direct", "dict"])
@pytest.mark.asyncio
async def test_asgi_middleware(unused_tcp_port: int, method: str):
middlewares: Sequence[Callable[..., ASGI3Application] | dict[str, Any]]
if method == "direct":
middlewares = [lambda app: TextReplacerMiddleware(app, "World", "Middleware")]
else:
middlewares = [
{
"type": f"{__name__}:TextReplacerMiddleware",
"text": "World",
"replacement": "Middleware",
}
]

async with Context() as ctx, AsyncClient() as http:
ctx.add_resource("foo")
ctx.add_resource("bar", name="another")
await ASGIComponent(
app=application, port=unused_tcp_port, middlewares=middlewares
).start(ctx)

# Ensure that the application got added as a resource
ctx.require_resource(ASGI3Application)

# Ensure that the application responds correctly to an HTTP request
response = await http.get(
f"http://127.0.0.1:{unused_tcp_port}", params={"param": "Hello World"}
)
response.raise_for_status()
assert response.json() == {
"message": "Hello Middleware",
"my resource": "foo",
"another resource": "bar",
}
4 changes: 4 additions & 0 deletions tests/test_starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from asphalt.core import Component, Context, inject, require_resource, resource
from httpx import AsyncClient
from starlette.applications import Starlette
from starlette.middleware.trustedhost import TrustedHostMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.websockets import WebSocket
Expand All @@ -30,6 +31,9 @@ async def root(
)


TrustedHostMiddleware


@inject
async def ws_root(
websocket: WebSocket,
Expand Down

0 comments on commit 2ba60ce

Please sign in to comment.