Skip to content

Commit

Permalink
Extract query params from Components resolve methods
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Dec 21, 2018
1 parent c561882 commit 2551851
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 24 deletions.
13 changes: 11 additions & 2 deletions starlette_api/applications.py
Expand Up @@ -6,6 +6,7 @@
from starlette.middleware.lifespan import LifespanMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.types import ASGIApp

from starlette_api import exceptions
from starlette_api.components import Component
Expand All @@ -24,9 +25,9 @@ def __init__(
components = []

# Initialize injector
self.injector = Injector(components)
self.components = components

self.router = Router()
self.router = Router(components=components)
self.app = self.router
self.exception_middleware = ExceptionMiddleware(self.router, debug=debug)
self.error_middleware = ServerErrorMiddleware(self.exception_middleware, debug=debug)
Expand All @@ -35,5 +36,13 @@ def __init__(
# Add exception handler for API exceptions
self.add_exception_handler(exceptions.HTTPException, self.api_http_exception_handler)

@property
def injector(self):
return Injector(components=self.components)

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
self.components += getattr(app, "components", [])
self.router.mount(path, app=app, name=name)

def api_http_exception_handler(self, request: Request, exc: HTTPException) -> Response:
return JSONResponse(exc.detail, exc.status_code)
2 changes: 1 addition & 1 deletion starlette_api/components/validation.py
Expand Up @@ -53,7 +53,7 @@ def resolve(self, request: http.Request, route: Route, query_params: http.QueryP
)

try:
query_params = validator().load(dict(query_params))
query_params = validator().load(dict(query_params), unknown=marshmallow.EXCLUDE)
except marshmallow.ValidationError as exc:
raise exceptions.ValidationError(detail=exc.normalized_messages())
return ValidatedQueryParams(query_params)
Expand Down
56 changes: 43 additions & 13 deletions starlette_api/routing.py
Expand Up @@ -12,6 +12,7 @@

import marshmallow
from starlette_api import http
from starlette_api.components import Component

__all__ = ["Route", "WebSocketRoute", "Router"]

Expand All @@ -35,7 +36,7 @@ class Field:


class FieldsMixin:
def _get_fields(self) -> typing.Tuple[MethodsMap, MethodsMap, typing.Dict[str, Field]]:
def _get_fields(self, router: "Router") -> typing.Tuple[MethodsMap, MethodsMap, typing.Dict[str, Field]]:
query_fields: MethodsMap = {}
path_fields: MethodsMap = {}
body_field: typing.Dict[str, Field] = {}
Expand All @@ -49,18 +50,36 @@ def _get_fields(self) -> typing.Tuple[MethodsMap, MethodsMap, typing.Dict[str, F
methods = [("GET", self.endpoint)]

for method, handler in methods:
query_fields[method], path_fields[method], body_field[method] = self._get_fields_from_handler(handler)
query_fields[method], path_fields[method], body_field[method] = self._get_fields_from_handler(
handler, router
)

return query_fields, path_fields, body_field

def _get_fields_from_handler(self, handler: typing.Callable) -> typing.Tuple[FieldsMap, FieldsMap, Field]:
def _get_parameters_from_handler(
self, handler: typing.Callable, router: "Router"
) -> typing.Dict[str, inspect.Parameter]:
parameters = {}

for name, parameter in inspect.signature(handler).parameters.items():
for component in router.components:
if component.can_handle_parameter(parameter):
parameters.update(self._get_parameters_from_handler(component.resolve, router))
break
else:
parameters[name] = parameter

return parameters

def _get_fields_from_handler(
self, handler: typing.Callable, router: "Router"
) -> typing.Tuple[FieldsMap, FieldsMap, Field]:
query_fields: FieldsMap = {}
path_fields: FieldsMap = {}
body_field: Field = None

# Iterate over all params
parameters = inspect.signature(handler).parameters
for name, param in parameters.items():
for name, param in self._get_parameters_from_handler(handler, router).items():
if name in ("self", "cls"):
continue

Expand Down Expand Up @@ -98,14 +117,14 @@ def _get_fields_from_handler(self, handler: typing.Callable) -> typing.Tuple[Fie


class Route(starlette.routing.Route, FieldsMixin):
def __init__(self, path: str, endpoint: typing.Callable, *args, **kwargs):
def __init__(self, path: str, endpoint: typing.Callable, router: "Router", *args, **kwargs):
super().__init__(path, endpoint=endpoint, **kwargs)

# Replace function with another wrapper that uses the injector
if inspect.isfunction(endpoint):
self.app = self.endpoint_wrapper(endpoint)

self.query_fields, self.path_fields, self.body_field = self._get_fields()
self.query_fields, self.path_fields, self.body_field = self._get_fields(router)

def endpoint_wrapper(self, endpoint: typing.Callable) -> ASGIApp:
"""
Expand Down Expand Up @@ -157,14 +176,14 @@ async def awaitable(receive: Receive, send: Send) -> None:


class WebSocketRoute(starlette.routing.WebSocketRoute, FieldsMixin):
def __init__(self, path: str, endpoint: typing.Callable, *args, **kwargs):
def __init__(self, path: str, endpoint: typing.Callable, router: "Router", *args, **kwargs):
super().__init__(path, endpoint=endpoint, **kwargs)

# Replace function with another wrapper that uses the injector
if inspect.isfunction(endpoint):
self.app = self.endpoint_wrapper(endpoint)

self.query_fields, self.path_fields, self.body_field = self._get_fields()
self.query_fields, self.path_fields, self.body_field = self._get_fields(router)

def endpoint_wrapper(self, endpoint: typing.Callable) -> ASGIApp:
"""
Expand Down Expand Up @@ -199,6 +218,10 @@ async def awaitable(receive: Receive, send: Send) -> None:


class Router(starlette.routing.Router):
def __init__(self, components: typing.Optional[typing.List[Component]] = None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.components = components

def add_route(
self,
path: str,
Expand All @@ -208,9 +231,19 @@ def add_route(
include_in_schema: bool = True,
):
self.routes.append(
Route(path, endpoint=endpoint, methods=methods, name=name, include_in_schema=include_in_schema)
Route(path, endpoint=endpoint, methods=methods, name=name, include_in_schema=include_in_schema, router=self)
)

def add_websocket_route(self, path: str, endpoint: typing.Callable, name: str = None):
self.routes.append(WebSocketRoute(path, endpoint=endpoint, name=name, router=self))

def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
if isinstance(app, Router):
app.components = self.components

route = Mount(path, app=app, name=name)
self.routes.append(route)

def route(
self, path: str, methods: typing.List[str] = None, name: str = None, include_in_schema: bool = True
) -> typing.Callable:
Expand All @@ -220,9 +253,6 @@ def decorator(func: typing.Callable) -> typing.Callable:

return decorator

def add_websocket_route(self, path: str, endpoint: typing.Callable, name: str = None):
self.routes.append(WebSocketRoute(path, endpoint=endpoint, name=name))

def websocket_route(self, path: str, name: str = None) -> typing.Callable:
def decorator(func: typing.Callable) -> typing.Callable:
self.add_websocket_route(path, func, name=name)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_custom_components.py
Expand Up @@ -80,17 +80,17 @@ def test_unknown_component(self, client):
client.get("/unknown")

def test_unhandled_component(self):
app_ = Starlette(components=[UnhandledComponent()])

@app_.route("/")
def foo(unknown: Unknown):
return JSONResponse({"foo": "bar"})

client = TestClient(app_)

with pytest.raises(
ConfigurationError,
match=r'Component "UnhandledComponent" must include a return annotation on the `resolve\(\)` method, '
"or override `can_handle_parameter`",
):
app_ = Starlette(components=[UnhandledComponent()])

@app_.route("/")
def foo(unknown: Unknown):
return JSONResponse({"foo": "bar"})

client = TestClient(app_)

client.get("/")

0 comments on commit 2551851

Please sign in to comment.