Skip to content

Commit

Permalink
Initial tag support (#30)
Browse files Browse the repository at this point in the history
* Initial tag support

* Docs
  • Loading branch information
Tinche committed Apr 24, 2023
1 parent a2a3874 commit 5e49a51
Show file tree
Hide file tree
Showing 16 changed files with 141 additions and 64 deletions.
23 changes: 23 additions & 0 deletions docs/openapi.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,26 @@ spec = app.make_openapi_spec()
# Serve the schema at /openapi.json by default
app.serve_openapi()
```

Additionally, _uapi_ also supports serving several OpenAPI documentation viewers:

```python
app.serve_swaggerui()
app.serve_redoc()
app.serve_elements()
```

The documentation viewer will be available at its default URL.

## Endpoint Tags

OpenAPI supports grouping endpoints by tags.
You can specify tags for each handler when registering it:

```python
@app.get("/{article_id}", tags=["articles"])
async def get_article(article_id: str) -> str:
return "Getting the article"
```

Depending on the OpenAPI visualization framework used, endpoints with tags are usually displayed grouped under the tag.
2 changes: 1 addition & 1 deletion src/uapi/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def request_bytes(_request: FrameworkRequest) -> bytes:
def to_framework_routes(self) -> RouteTableDef:
r = RouteTableDef()

for (method, path), (handler, name) in self._route_map.items():
for (method, path), (handler, name, _) in self._route_map.items():
ra = make_return_adapter(
signature(handler, eval_str=True).return_annotation,
FrameworkResponse,
Expand Down
55 changes: 32 additions & 23 deletions src/uapi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .openapi import converter as openapi_converter
from .openapi import default_summary_transformer, make_openapi_spec
from .status import Ok
from .types import Method, PathParamParser
from .types import Method, PathParamParser, RouteName, RouteTags


def make_base_incanter() -> Incanter:
Expand All @@ -30,7 +30,9 @@ class OpenAPISecuritySpec:
class App:
converter: Converter = Factory(make_converter)
base_incant: Incanter = Factory(make_base_incanter)
_route_map: dict[tuple[Method, str], tuple[Callable, str]] = Factory(dict)
_route_map: dict[
tuple[Method, str], tuple[Callable, RouteName, RouteTags]
] = Factory(dict)
_openapi_security: list[OpenAPISecuritySpec] = Factory(list)
_path_param_parser: ClassVar[PathParamParser] = lambda p: (p, [])
_framework_req_cls: ClassVar[type] = NoneType
Expand All @@ -40,55 +42,59 @@ def route(
self,
path: str,
handler,
name: str | None = None,
methods: Sequence[Method] = ["GET"],
name: str | None = None,
tags: RouteTags = (),
):
"""Register routes. This is not a decorator."""
"""Register routes. This is not a decorator.
:param tags: The OpenAPI tags to apply.
"""
if name is None:
name = handler.__name__
for method in methods:
self._route_map[(method, path)] = (handler, name)
self._route_map[(method, path)] = (handler, name, tags)
return handler

def get(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["GET"])
def get(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["GET"], tags=tags)

def post(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["POST"])
def post(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["POST"], tags=tags)

def put(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["PUT"])
def put(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["PUT"], tags=tags)

def patch(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["PATCH"])
def patch(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["PATCH"], tags=tags)

def delete(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["DELETE"])
def delete(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["DELETE"], tags=tags)

def head(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["HEAD"])
def head(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["HEAD"], tags=tags)

def options(self, path: str, name: str | None = None):
return partial(self.route, path, name=name, methods=["OPTIONS"])
def options(self, path: str, name: str | None = None, tags: RouteTags = ()):
return partial(self.route, path, name=name, methods=["OPTIONS"], tags=tags)

def route_app(
self, app: "App", prefix: str | None = None, name_prefix: str | None = None
) -> None:
"""Register all routes from a different app under an optional path prefix."""
if not isinstance(self, type(app)):
raise Exception("Incompatible apps.")
for (method, path), (handler, name) in app._route_map.items():
for (method, path), (handler, name, tags) in app._route_map.items():
if name_prefix is not None:
if name is None:
name = handler.__name__
name = f"{name_prefix}.{name}"
self._route_map[(method, (prefix or "") + path)] = (handler, name)
self._route_map[(method, (prefix or "") + path)] = (handler, name, tags)

def make_openapi_spec(
self,
title: str = "Server",
version: str = "1.0",
exclude: set[str] = set(),
exclude: set[RouteName] = set(),
summary_transformer: SummaryTransformer = default_summary_transformer,
) -> OpenAPI:
"""
Expand All @@ -100,7 +106,7 @@ def make_openapi_spec(
"""
# We need to prepare the handlers to get the correct signature.
route_map = {
k: (self.base_incant.prepare(v[0]), v[1])
k: (self.base_incant.prepare(v[0]), v[1], v[2])
for k, v in self._route_map.items()
if v[1] not in exclude
}
Expand Down Expand Up @@ -140,6 +146,7 @@ def openapi_handler() -> Ok[bytes]:
self.route(path, openapi_handler)

def serve_swaggerui(self, path: str = "/swaggerui"):
"""Start serving the Swagger UI at the given path."""
from .openapi_ui import swaggerui

def swaggerui_handler() -> Ok[str]:
Expand All @@ -148,6 +155,7 @@ def swaggerui_handler() -> Ok[str]:
self.route(path, swaggerui_handler)

def serve_redoc(self, path: str = "/redoc"):
"""Start serving the ReDoc UI at the given path."""
from .openapi_ui import redoc

def redoc_handler() -> Ok[str]:
Expand All @@ -158,6 +166,7 @@ def redoc_handler() -> Ok[str]:
def serve_elements(
self, path: str = "/elements", openapi_path: str = "/openapi.json"
):
"""Start serving the OpenAPI Elements UI at the given path."""
from .openapi_ui import elements as elements_html

fixed_path = elements_html.replace("$OPENAPIURL", openapi_path)
Expand Down
12 changes: 7 additions & 5 deletions src/uapi/django.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
)
from .responses import dict_to_headers, identity, make_return_adapter
from .status import BaseResponse, get_status_code
from .types import Method, PathParamParser
from .types import Method, PathParamParser, RouteName, RouteTags

C = TypeVar("C")

Expand Down Expand Up @@ -120,15 +120,17 @@ class DjangoApp(BaseApp):
def to_urlpatterns(self) -> list[URLPattern]:
res = []

by_path_by_method: dict[str, dict[Method, tuple[Callable, str]]] = {}
for (method, path), (handler, name) in self._route_map.items():
by_path_by_method.setdefault(path, {})[method] = (handler, name)
by_path_by_method: dict[
str, dict[Method, tuple[Callable, RouteName, RouteTags]]
] = {}
for (method, path), v in self._route_map.items():
by_path_by_method.setdefault(path, {})[method] = v

for path, methods_and_handlers in by_path_by_method.items():
# Django does not strip the prefix slash, so we do it for it.
path = path.removeprefix("/")
per_method_adapted = {}
for method, (handler, name) in methods_and_handlers.items():
for method, (handler, name, _) in methods_and_handlers.items():
ra = make_return_adapter(
signature(handler, eval_str=True).return_annotation,
FrameworkResponse,
Expand Down
2 changes: 1 addition & 1 deletion src/uapi/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class FlaskApp(BaseApp):
def to_framework_app(self, import_name: str) -> Flask:
f = Flask(import_name)

for (method, path), (handler, name) in self._route_map.items():
for (method, path), (handler, name, _) in self._route_map.items():
ra = make_return_adapter(
signature(handler, eval_str=True).return_annotation,
FrameworkResponse,
Expand Down
25 changes: 17 additions & 8 deletions src/uapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from .requests import get_cookie_name, maybe_header_type, maybe_req_body_attrs
from .responses import get_status_code_results
from .status import BaseResponse
from .types import Method, PathParamParser, Routes, is_subclass
from .types import Method, PathParamParser, RouteName, Routes, RouteTags, is_subclass

converter = make_converter(omit_if_default=True)

# MediaTypeNames are like `application/json`.
MediaTypeName = str

SummaryTransformer = Callable[[Callable, str], str | None]
SummaryTransformer: TypeAlias = Callable[[Callable, str], str | None]


def default_summary_transformer(handler: Callable, name: str) -> str:
Expand Down Expand Up @@ -133,6 +133,7 @@ class Operation:
requestBody: RequestBody | None = None
security: list[SecurityRequirement] = Factory(list)
summary: str | None = None
tags: list[str] = Factory(list)

get: Operation | None = None
post: Operation | None = None
Expand Down Expand Up @@ -169,6 +170,7 @@ def build_operation(
framework_resp_cls: type | None,
security_schemas: Mapping[str, ApiKeySecurityScheme],
summary_transformer: SummaryTransformer,
tags: list[str],
) -> OpenAPI.PathItem.Operation:
request_bodies = {}
request_body_required = False
Expand Down Expand Up @@ -306,13 +308,13 @@ def build_operation(
security.append({sec_name: []})

return OpenAPI.PathItem.Operation(
responses, params, req_body, security, summary_transformer(handler, name)
responses, params, req_body, security, summary_transformer(handler, name), tags
)


def build_pathitem(
path: str,
path_routes: dict[Method, tuple[Callable, str]],
path_routes: dict[Method, tuple[Callable, RouteName, RouteTags]],
components: dict[type, str],
path_param_parser: PathParamParser,
framework_req_cls: type | None,
Expand All @@ -332,6 +334,7 @@ def build_pathitem(
framework_resp_cls,
security_schemas,
summary_transformer,
list(get_route[2]),
)
if post_route := path_routes.get("POST"):
post = build_operation(
Expand All @@ -344,6 +347,7 @@ def build_pathitem(
framework_resp_cls,
security_schemas,
summary_transformer,
list(post_route[2]),
)
if put_route := path_routes.get("PUT"):
put = build_operation(
Expand All @@ -356,6 +360,7 @@ def build_pathitem(
framework_resp_cls,
security_schemas,
summary_transformer,
list(put_route[2]),
)
if patch_route := path_routes.get("PATCH"):
patch = build_operation(
Expand All @@ -368,6 +373,7 @@ def build_pathitem(
framework_resp_cls,
security_schemas,
summary_transformer,
list(patch_route[2]),
)
if delete_route := path_routes.get("DELETE"):
delete = build_operation(
Expand All @@ -380,6 +386,7 @@ def build_pathitem(
framework_resp_cls,
security_schemas,
summary_transformer,
list(delete_route[2]),
)
return OpenAPI.PathItem(get, post, put, patch, delete)

Expand All @@ -393,11 +400,13 @@ def routes_to_paths(
security_schemas: Mapping[str, ApiKeySecurityScheme],
summary_transformer: SummaryTransformer,
) -> dict[str, OpenAPI.PathItem]:
res: dict[str, dict[Method, tuple[Callable, str]]] = defaultdict(dict)
res: dict[str, dict[Method, tuple[Callable, RouteName, RouteTags]]] = defaultdict(
dict
)

for (method, path), (handler, name) in routes.items():
for (method, path), (handler, name, tags) in routes.items():
path = path_param_parser(path)[0]
res[path] = res[path] | {method: (handler, name)}
res[path] = res[path] | {method: (handler, name, tags)}

return {
k: build_pathitem(
Expand Down Expand Up @@ -480,7 +489,7 @@ def components_to_openapi(
"""
# First pass, we build the component registry.
components: dict[type, str] = {}
for handler, _ in routes.values():
for handler, _, _ in routes.values():
gather_endpoint_components(handler, components)

res: dict[str, AnySchema | Reference] = {}
Expand Down
2 changes: 1 addition & 1 deletion src/uapi/quart.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class QuartApp(BaseApp):
def to_framework_app(self, import_name: str) -> Quart:
q = Quart(import_name)

for (method, path), (handler, name) in self._route_map.items():
for (method, path), (handler, name, _) in self._route_map.items():
ra = make_return_adapter(
signature(handler, eval_str=True).return_annotation,
FrameworkResponse,
Expand Down
2 changes: 1 addition & 1 deletion src/uapi/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class StarletteApp(BaseApp):
def to_framework_app(self) -> Starlette:
s = Starlette()

for (method, path), (handler, name) in self._route_map.items():
for (method, path), (handler, name, _) in self._route_map.items():
ra = make_return_adapter(
signature(handler, eval_str=True).return_annotation,
FrameworkResponse,
Expand Down
6 changes: 4 additions & 2 deletions src/uapi/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import Callable, Literal, TypeAlias, TypeVar
from typing import Callable, Literal, Sequence, TypeAlias, TypeVar

R = TypeVar("R")
CB = Callable[..., R]

RouteName: TypeAlias = str
RouteTags: TypeAlias = Sequence[str]
Method: TypeAlias = Literal["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]
Routes: TypeAlias = dict[tuple[Method, str], tuple[Callable, str]]
Routes: TypeAlias = dict[tuple[Method, str], tuple[Callable, RouteName, RouteTags]]
PathParamParser: TypeAlias = Callable[[str], tuple[str, list[str]]]


Expand Down
8 changes: 4 additions & 4 deletions tests/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ async def path_param(path_id: int) -> Response:
async def unannotated_exception() -> Response:
raise ResponseException(NoContent())

@app.get("/query/unannotated")
@app.get("/query/unannotated", tags=["query"])
async def query_unannotated(query) -> Response:
return Response(text=query + "suffix")

@app.get("/query/string")
@app.get("/query/string", tags=["query"])
async def query_string(query: str) -> Response:
return Response(text=query + "suffix")

@app.get("/query")
@app.get("/query", tags=["query"])
async def query_param(page: int) -> Response:
return Response(text=str(page + 1))

@app.get("/query-default")
@app.get("/query-default", tags=["query"])
async def query_default(page: int = 0) -> Response:
return Response(text=str(page + 1))

Expand Down
4 changes: 2 additions & 2 deletions tests/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def configure_base_async(app: App) -> None:
async def hello() -> str:
return "Hello, world"

@app.get("/query-bytes")
@app.get("/query-bytes", tags=["query"])
async def query_bytes() -> bytes:
return b"2"

Expand Down Expand Up @@ -191,7 +191,7 @@ def configure_base_sync(app: App) -> None:
def hello() -> str:
return "Hello, world"

@app.get("/query-bytes")
@app.get("/query-bytes", tags=["query"])
def query_bytes() -> bytes:
return b"2"

Expand Down
Loading

0 comments on commit 5e49a51

Please sign in to comment.