From 97ac15f7c8ef84fc28cfef3df79e5d599736e186 Mon Sep 17 00:00:00 2001 From: abersheeran Date: Sun, 9 Aug 2020 17:42:38 +0800 Subject: [PATCH] Add type hint support to use pydantic --- README.md | 28 +++++++++ poetry.lock | 57 ++++++++++++++++++- pyproject.toml | 7 ++- rpcpy/application.py | 116 ++++++++++++++++++++++++++++++++++---- rpcpy/client.py | 66 +++++++++++++++++++--- rpcpy/serializers.py | 3 + rpcpy/types.py | 12 ++++ rpcpy/utils/__init__.py | 77 ++++++++++++++++--------- rpcpy/utils/openapi.py | 97 +++++++++++++++++++++++++++++++ rpcpy/version.py | 2 +- tests/test_application.py | 88 ++++++++++++++++++++++++++++- 11 files changed, 499 insertions(+), 54 deletions(-) create mode 100644 rpcpy/utils/openapi.py diff --git a/README.md b/README.md index cc0bf55..02e2e76 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,9 @@ pip install git+https://github.com/abersheeran/rpc.py@setup.py ### Server side: +
+Use ASGI mode to register async def... + ```python import uvicorn from rpcpy import RPC @@ -48,9 +51,13 @@ async def yield_data(max_num: int): if __name__ == "__main__": uvicorn.run(app, interface="asgi3", port=65432) ``` +
OR +
+Use WSGI mode to register def... + ```python import uvicorn from rpcpy import RPC @@ -77,9 +84,15 @@ def yield_data(max_num: int): if __name__ == "__main__": uvicorn.run(app, interface="wsgi", port=65432) ``` +
### Client side: +Notice: Regardless of whether the server uses the WSGI mode or the ASGI mode, the client can freely use the asynchronous or synchronous mode. + +
+Use httpx.Client() mode to register def... + ```python import httpx from rpcpy.client import Client @@ -101,9 +114,13 @@ def sayhi(name: str) -> str: def yield_data(max_num: int): yield ``` +
OR +
+Use httpx.AsyncClient() mode to register async def... + ```python import httpx from rpcpy.client import Client @@ -125,6 +142,7 @@ async def sayhi(name: str) -> str: async def yield_data(max_num: int): yield ``` +
### Sub-route @@ -142,6 +160,16 @@ RPC(serializer=JSONSerializer()) RPC(serializer=PickleSerializer()) ``` +## Type hint and OpenAPI Doc + +Thanks to the great work of [pydantic](https://pydantic-docs.helpmanual.io/), which makes rpc.py allow you to use type annotation to annotate the types of function parameters and response values, and perform type verification and JSON serialization . At the same time, it is allowed to generate openapi documents for human reading. + +### OpenAPI Documents + +If you want to open the OpenAPI document, you need to initialize `RPC` like this `RPC(openapi={"title": "TITLE", "description": "DESCRIPTION", "version": "v1"})`. + +Then, visit the `"{prefix}openapi-docs"` of RPC and you will be able to see the automatically generated OpenAPI documentation. (If you do not set the `prefix`, the `prefix` is `"/"`) + ## Limitations Currently, function parameters must be serializable by `json`. diff --git a/poetry.lock b/poetry.lock index 47d467f..32fcae8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -94,6 +94,15 @@ version = "2.4" [package.dependencies] immutables = ">=0.9" +[[package]] +category = "main" +description = "A backport of the dataclasses module for Python 3.6" +marker = "python_version < \"3.7\"" +name = "dataclasses" +optional = true +python-versions = "*" +version = "0.6" + [[package]] category = "dev" description = "the modular source code checker: pep8 pyflakes and co" @@ -310,6 +319,24 @@ optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" version = "2.6.0" +[[package]] +category = "main" +description = "Data validation and settings management using python 3.6 type hinting" +name = "pydantic" +optional = true +python-versions = ">=3.6" +version = "1.6.1" + +[package.dependencies] +[package.dependencies.dataclasses] +python = "<3.7" +version = ">=0.6" + +[package.extras] +dotenv = ["python-dotenv (>=0.10.4)"] +email = ["email-validator (>=1.0.3)"] +typing_extensions = ["typing-extensions (>=3.7.2)"] + [[package]] category = "dev" description = "passive checker of Python programs" @@ -423,7 +450,7 @@ python-versions = "*" version = "1.4.1" [[package]] -category = "dev" +category = "main" description = "Backported and Experimental Type Hints for Python 3.5+" name = "typing-extensions" optional = false @@ -453,10 +480,11 @@ testing = ["jaraco.itertools", "func-timeout"] [extras] client = ["httpx"] -full = ["httpx"] +full = ["httpx", "pydantic"] +type = ["pydantic"] [metadata] -content-hash = "19da082139ff7259d57402210c6735934f87b62178832eec7552220cc18f3b79" +content-hash = "3361e14f9f0aec6e206ed0395dcb93a164d70119be1c542665928a536bfbceee" lock-version = "1.0" python-versions = "^3.6" @@ -496,6 +524,10 @@ colorama = [ contextvars = [ {file = "contextvars-2.4.tar.gz", hash = "sha256:f38c908aaa59c14335eeea12abea5f443646216c4e29380d7bf34d2018e2c39e"}, ] +dataclasses = [ + {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"}, + {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"}, +] flake8 = [ {file = "flake8-3.8.3-py2.py3-none-any.whl", hash = "sha256:15e351d19611c887e482fb960eae4d44845013cc142d42896e9862f775d8cf5c"}, {file = "flake8-3.8.3.tar.gz", hash = "sha256:f04b9fcbac03b0a3e58c0ab3a0ecc462e023a9faf046d57794184028123aa208"}, @@ -598,6 +630,25 @@ pycodestyle = [ {file = "pycodestyle-2.6.0-py2.py3-none-any.whl", hash = "sha256:2295e7b2f6b5bd100585ebcb1f616591b652db8a741695b3d8f5d28bdc934367"}, {file = "pycodestyle-2.6.0.tar.gz", hash = "sha256:c58a7d2815e0e8d7972bf1803331fb0152f867bd89adf8a01dfd55085434192e"}, ] +pydantic = [ + {file = "pydantic-1.6.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:418b84654b60e44c0cdd5384294b0e4bc1ebf42d6e873819424f3b78b8690614"}, + {file = "pydantic-1.6.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:4900b8820b687c9a3ed753684337979574df20e6ebe4227381d04b3c3c628f99"}, + {file = "pydantic-1.6.1-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:b49c86aecde15cde33835d5d6360e55f5e0067bb7143a8303bf03b872935c75b"}, + {file = "pydantic-1.6.1-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:2de562a456c4ecdc80cf1a8c3e70c666625f7d02d89a6174ecf63754c734592e"}, + {file = "pydantic-1.6.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f769141ab0abfadf3305d4fcf36660e5cf568a666dd3efab7c3d4782f70946b1"}, + {file = "pydantic-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2dc946b07cf24bee4737ced0ae77e2ea6bc97489ba5a035b603bd1b40ad81f7e"}, + {file = "pydantic-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:36dbf6f1be212ab37b5fda07667461a9219c956181aa5570a00edfb0acdfe4a1"}, + {file = "pydantic-1.6.1-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:1783c1d927f9e1366e0e0609ae324039b2479a1a282a98ed6a6836c9ed02002c"}, + {file = "pydantic-1.6.1-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:cf3933c98cb5e808b62fae509f74f209730b180b1e3c3954ee3f7949e083a7df"}, + {file = "pydantic-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f8af9b840a9074e08c0e6dc93101de84ba95df89b267bf7151d74c553d66833b"}, + {file = "pydantic-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:40d765fa2d31d5be8e29c1794657ad46f5ee583a565c83cea56630d3ae5878b9"}, + {file = "pydantic-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3fa799f3cfff3e5f536cbd389368fc96a44bb30308f258c94ee76b73bd60531d"}, + {file = "pydantic-1.6.1-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:6c3f162ba175678218629f446a947e3356415b6b09122dcb364e58c442c645a7"}, + {file = "pydantic-1.6.1-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:eb75dc1809875d5738df14b6566ccf9fd9c0bcde4f36b72870f318f16b9f5c20"}, + {file = "pydantic-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:530d7222a2786a97bc59ee0e0ebbe23728f82974b1f1ad9a11cd966143410633"}, + {file = "pydantic-1.6.1-py36.py37.py38-none-any.whl", hash = "sha256:b5b3489cb303d0f41ad4a7390cf606a5f2c7a94dcba20c051cd1c653694cb14d"}, + {file = "pydantic-1.6.1.tar.gz", hash = "sha256:54122a8ed6b75fe1dd80797f8251ad2063ea348a03b77218d73ea9fe19bd4e73"}, +] pyflakes = [ {file = "pyflakes-2.2.0-py2.py3-none-any.whl", hash = "sha256:0d94e0e05a19e57a99444b6ddcf9a6eb2e5c68d3ca1e98e90707af8152c90a92"}, {file = "pyflakes-2.2.0.tar.gz", hash = "sha256:35b2d75ee967ea93b55750aa9edbbf72813e06a66ba54438df2cfac9e3c27fc8"}, diff --git a/pyproject.toml b/pyproject.toml index ad346a1..74890ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "rpc.py" -version = "0.2.2" +version = "0.3.0" description = "An easy-to-use and powerful RPC framework. Base WSGI & ASGI." authors = ["abersheeran "] readme = "README.md" @@ -20,11 +20,14 @@ packages = [ [tool.poetry.dependencies] python = "^3.6" +typing-extensions = {version = "^3.7.4", python = "<3.8"} httpx = {version = "^0.13.3", optional = true} # for client and test +pydantic = {version = "^1.6.1", optional = true} # for openapi docs [tool.poetry.extras] client = ["httpx"] -full = ["httpx"] +type = ["pydantic"] +full = ["httpx", "pydantic"] [tool.poetry.dev-dependencies] flake8 = "*" diff --git a/rpcpy/application.py b/rpcpy/application.py index 8eaaaac..fdea4d4 100644 --- a/rpcpy/application.py +++ b/rpcpy/application.py @@ -1,9 +1,11 @@ import typing import inspect +import copy +import json from base64 import b64encode from types import FunctionType -from rpcpy.types import Environ, StartResponse, Scope, Receive, Send +from rpcpy.types import Environ, StartResponse, Scope, Receive, Send, TypedDict from rpcpy.serializers import BaseSerializer, JSONSerializer from rpcpy.asgi import ( Request as ASGIRequest, @@ -15,10 +17,21 @@ Response as WSGIResponse, EventResponse as WSGIEventResponse, ) +from rpcpy.utils import set_type_model +from rpcpy.utils.openapi import ( + BaseModel, + schema_request_body, + schema_response, + TEMPLATE as OpenapiTemplate, +) __all__ = ["RPC"] Function = typing.TypeVar("Function", FunctionType, FunctionType) +MethodNotAllowedHttpCode = { + "GET": 404, + "HEAD": 404, +} # https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Status/405 class RPCMeta(type): @@ -36,6 +49,9 @@ def __call__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: return super().__call__(*args, **kwargs) +OpenAPI = TypedDict("OpenAPI", {"title": str, "description": str, "version": str}) + + class RPC(metaclass=RPCMeta): def __init__( self, @@ -43,24 +59,64 @@ def __init__( prefix: str = "/", mode: str = "WSGI", serializer: BaseSerializer = JSONSerializer(), + openapi: OpenAPI = None, ): assert mode in ("WSGI", "ASGI"), "mode must be in ('WSGI', 'ASGI')" assert prefix.startswith("/") and prefix.endswith("/") self.callbacks: typing.Dict[str, typing.Callable] = {} self.prefix = prefix self.serializer = serializer + self.openapi = openapi def register(self, func: Function) -> Function: self.callbacks[func.__name__] = func + set_type_model(func) return func + def get_openapi_docs(self) -> dict: + openapi: dict = { + "openapi": "3.0.0", + "info": copy.deepcopy(self.openapi) or {}, + "paths": {}, + } + + for name, callback in self.callbacks.items(): + _ = {} + # summary and description + doc = callback.__doc__ + if isinstance(doc, str): + doc = doc.strip() + _.update( + { + "summary": doc.splitlines()[0], + "description": "\n".join(doc.splitlines()[1:]).strip(), + } + ) + # request body + body_doc = schema_request_body(getattr(callback, "__body_model__", None)) + if body_doc: + _["requestBody"] = body_doc + # response & only 200 + sig = inspect.signature(callback) + if ( + sig.return_annotation != sig.empty + and inspect.isclass(sig.return_annotation) + and issubclass(sig.return_annotation, BaseModel) + ): + _["responses"] = { + 200: {"content": schema_response(sig.return_annotation)} + } + if _: + openapi["paths"][f"{self.prefix}{name}"] = {"post": _} + + return openapi + class WSGIRPC(RPC): def register(self, func: Function) -> Function: if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func): raise TypeError("WSGI mode can only register synchronization functions.") - self.callbacks[func.__name__] = func - return func + return super().register(func) def create_generator( self, generator: typing.Generator @@ -72,14 +128,31 @@ def __call__( self, environ: Environ, start_response: StartResponse ) -> typing.Iterable[bytes]: request = WSGIRequest(environ) + if self.openapi is not None and request.method == "GET": + if request.url.path[len(self.prefix) :] == "openapi-docs": + return WSGIResponse( + OpenapiTemplate, headers={"content-type": "text/html"}, + )(environ, start_response) + elif request.url.path[len(self.prefix) :] == "get-openapi-docs": + return WSGIResponse( + json.dumps(self.get_openapi_docs(), ensure_ascii=False), + headers={"content-type": "application/json"}, + )(environ, start_response) + if request.method != "POST": - return WSGIResponse(status_code=405)(environ, start_response) + return WSGIResponse( + status_code=MethodNotAllowedHttpCode.get(request.method, 405) + )(environ, start_response) content_type = request.headers["content-type"] assert content_type == "application/json" data = request.json - result = self.callbacks[request.url.path[len(self.prefix) :]](**data) + callback = self.callbacks[request.url.path[len(self.prefix) :]] + if hasattr(callback, "__body_model__"): + result = callback(**getattr(callback, "__body_model__")(**data).dict()) + else: + result = callback(**data) if inspect.isgenerator(result): return WSGIEventResponse( @@ -89,7 +162,10 @@ def __call__( return WSGIResponse( self.serializer.encode(result), - headers={"serializer": self.serializer.name}, + headers={ + "serializer": self.serializer.name, + "content-type": self.serializer.content_type, + }, )(environ, start_response) @@ -97,8 +173,7 @@ class ASGIRPC(RPC): def register(self, func: Function) -> Function: if not (inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)): raise TypeError("ASGI mode can only register asynchronous functions.") - self.callbacks[func.__name__] = func - return func + return super().register(func) async def create_generator( self, generator: typing.AsyncGenerator @@ -108,15 +183,31 @@ async def create_generator( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = ASGIRequest(scope, receive, send) + if self.openapi is not None and request.method == "GET": + if request.url.path[len(self.prefix) :] == "openapi-docs": + return await ASGIResponse( + OpenapiTemplate, headers={"content-type": "text/html"}, + )(scope, receive, send) + elif request.url.path[len(self.prefix) :] == "get-openapi-docs": + return await ASGIResponse( + json.dumps(self.get_openapi_docs(), ensure_ascii=False), + headers={"content-type": "application/json"}, + )(scope, receive, send) if request.method != "POST": - return await ASGIResponse(status_code=405)(scope, receive, send) + return await ASGIResponse( + status_code=MethodNotAllowedHttpCode.get(request.method, 405) + )(scope, receive, send) content_type = request.headers["content-type"] assert content_type == "application/json" data = await request.json() - result = self.callbacks[request.url.path[len(self.prefix) :]](**data) + callback = self.callbacks[request.url.path[len(self.prefix) :]] + if hasattr(callback, "__body_model__"): + result = callback(**getattr(callback, "__body_model__")(**data).dict()) + else: + result = callback(**data) if inspect.isasyncgen(result): return await ASGIEventResponse( @@ -126,5 +217,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return await ASGIResponse( self.serializer.encode(await result), - headers={"serializer": self.serializer.name}, + headers={ + "serializer": self.serializer.name, + "content-type": self.serializer.content_type, + }, )(scope, receive, send) diff --git a/rpcpy/client.py b/rpcpy/client.py index 95a9fe4..11283ce 100644 --- a/rpcpy/client.py +++ b/rpcpy/client.py @@ -1,12 +1,14 @@ import typing import inspect import functools +import json from base64 import b64decode from types import FunctionType import httpx from rpcpy.serializers import get_serializer +from rpcpy.utils import set_type_model __all__ = ["Client"] @@ -24,12 +26,12 @@ def __init__( def remote_call(self, func: Function) -> Function: is_async = inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func) - + set_type_model(func) # try set `__body_model__` if is_async: - return self.async_remote_call(func) - return self.sync_remote_call(func) + return self.__async_remote_call(func) + return self.__sync_remote_call(func) - def async_remote_call(self, func: Function) -> Function: + def __async_remote_call(self, func: Function) -> Function: if not self.is_async: raise TypeError( "Synchronization Client can only register synchronization functions." @@ -41,9 +43,19 @@ def async_remote_call(self, func: Function) -> Function: async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: sig = inspect.signature(func) bound_values = sig.bind(*args, **kwargs) + if hasattr(func, "__body_model__"): + post_json = ( + getattr(func, "__body_model__")(**bound_values.arguments) + .json() + .encode("utf8") + ) + else: + post_json = json.dumps(dict(**bound_values.arguments)).encode( + "utf8" + ) url = self.base_url + func.__name__ resp: httpx.Response = await self.client.post( # type: ignore - url, json=dict(bound_values.arguments.items()) + url, data=post_json, headers={"content-type": "application/json"} ) resp.raise_for_status() serializer = get_serializer(resp.headers) @@ -55,9 +67,22 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: sig = inspect.signature(func) bound_values = sig.bind(*args, **kwargs) + if hasattr(func, "__body_model__"): + post_json = ( + getattr(func, "__body_model__")(**bound_values.arguments) + .json() + .encode("utf8") + ) + else: + post_json = json.dumps(dict(**bound_values.arguments)).encode( + "utf8" + ) url = self.base_url + func.__name__ async with self.client.stream( - "POST", url, json=dict(bound_values.arguments.items()) + "POST", + url, + data=post_json, + headers={"content-type": "application/json"}, ) as resp: # type: httpx.Response serializer = get_serializer(resp.headers) # I don't know how to solve this error: @@ -70,7 +95,7 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: return typing.cast(Function, wrapper) - def sync_remote_call(self, func: Function) -> Function: + def __sync_remote_call(self, func: Function) -> Function: if self.is_async: raise TypeError( "Asynchronous Client can only register asynchronous functions." @@ -81,9 +106,19 @@ def sync_remote_call(self, func: Function) -> Function: def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: sig = inspect.signature(func) bound_values = sig.bind(*args, **kwargs) + if hasattr(func, "__body_model__"): + post_json = ( + getattr(func, "__body_model__")(**bound_values.arguments) + .json() + .encode("utf8") + ) + else: + post_json = json.dumps(dict(**bound_values.arguments)).encode( + "utf8" + ) url = self.base_url + func.__name__ resp: httpx.Response = self.client.post( # type: ignore - url, json=dict(bound_values.arguments.items()), + url, data=post_json, headers={"content-type": "application/json"} ) resp.raise_for_status() serializer = get_serializer(resp.headers) @@ -95,9 +130,22 @@ def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: sig = inspect.signature(func) bound_values = sig.bind(*args, **kwargs) + if hasattr(func, "__body_model__"): + post_json = ( + getattr(func, "__body_model__")(**bound_values.arguments) + .json() + .encode("utf8") + ) + else: + post_json = json.dumps(dict(**bound_values.arguments)).encode( + "utf8" + ) url = self.base_url + func.__name__ with self.client.stream( - "POST", url, json=dict(bound_values.arguments.items()) + "POST", + url, + data=post_json, + headers={"content-type": "application/json"}, ) as resp: # type: httpx.Response serializer = get_serializer(resp.headers) for line in resp.iter_lines(): diff --git a/rpcpy/serializers.py b/rpcpy/serializers.py index f81334c..a64cdde 100644 --- a/rpcpy/serializers.py +++ b/rpcpy/serializers.py @@ -12,6 +12,7 @@ class BaseSerializer(metaclass=ABCMeta): """ name: str + content_type: str @abstractmethod def encode(self, data: typing.Any) -> bytes: @@ -28,6 +29,7 @@ def json_default(obj: typing.Any) -> typing.Any: class JSONSerializer(BaseSerializer): name = "json" + content_type = "application/json" def __init__(self, default: typing.Callable = json_default) -> None: self.default = default @@ -41,6 +43,7 @@ def decode(self, data: bytes) -> typing.Any: class PickleSerializer(BaseSerializer): name = "pickle" + content_type = "application/x-pickle" def encode(self, data: typing.Any) -> bytes: return pickle.dumps(data) diff --git a/rpcpy/types.py b/rpcpy/types.py index 9a9e6f7..86c56a2 100644 --- a/rpcpy/types.py +++ b/rpcpy/types.py @@ -1,3 +1,4 @@ +import sys from types import TracebackType from typing import ( Any, @@ -10,6 +11,11 @@ Awaitable, ) +if sys.version_info[:2] < (3, 8): + from typing_extensions import TypedDict, Literal, Final, final +else: + from typing import TypedDict, Literal, Final, final + __all__ = [ "Scope", "Message", @@ -20,6 +26,12 @@ "Environ", "StartResponse", "WSGIApp", +] + [ + # built-in types + "TypedDict", + "Literal", + "Final", + "final", ] # ASGI diff --git a/rpcpy/utils/__init__.py b/rpcpy/utils/__init__.py index ec20a8a..e5de010 100644 --- a/rpcpy/utils/__init__.py +++ b/rpcpy/utils/__init__.py @@ -1,6 +1,10 @@ import typing +import functools +import inspect from http import cookies as http_cookies +from .openapi import create_model + def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: """ @@ -25,37 +29,56 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]: return cookie_dict -class cached_property: - """ - A property that is only computed once per instance and then replaces - itself with an ordinary attribute. Deleting the attribute resets the - property. - """ +if typing.TYPE_CHECKING: + # https://github.com/python/mypy/issues/5107 + # for mypy check and IDE support + cached_property = property +else: + + class cached_property: + """ + A property that is only computed once per instance and then replaces + itself with an ordinary attribute. Deleting the attribute resets the + property. + """ - def __init__(self, func: typing.Callable) -> None: - self.__doc__ = getattr(func, "__doc__") - self.func = func + def __init__(self, func: typing.Callable) -> None: + self.func = func + functools.update_wrapper(self, func) - def __get__(self, obj: typing.Any, cls: typing.Any) -> typing.Any: - if obj is None: - return self - value = obj.__dict__[self.func.__name__] = self.func(obj) - return value + def __get__(self, obj: typing.Any, cls: typing.Any) -> typing.Any: + if obj is None: + return self + value = obj.__dict__[self.func.__name__] = self.func(obj) + return value -def merge_list( - raw: typing.List[typing.Tuple[str, str]] -) -> typing.Dict[str, typing.Union[typing.List[str], str]]: +Function = typing.TypeVar("Function", typing.Callable, typing.Callable) + + +def set_type_model(func: Function) -> Function: """ - If there are values with the same key value, they are merged into a List. + try generate request body model from type hint and default value """ - d: typing.Dict[str, typing.Union[typing.List[str], str]] = {} - for k, v in raw: - if k in d: - if isinstance(d[k], list): - typing.cast(typing.List, d[k]).append(v) - else: - d[k] = [typing.cast(str, d[k]), v] + sig = inspect.signature(func) + field_definitions = {} + for name, parameter in sig.parameters.items(): + if ( + parameter.annotation == parameter.empty + and parameter.default == parameter.empty + ): + # raise ValueError( + # f"You must specify the type for the parameter {func.__name__}:{name}." + # ) + return func # Maybe the type hint should be mandatory? I'm not sure. + if parameter.annotation == parameter.empty: + field_definitions[name] = parameter.default + elif parameter.default == parameter.empty: + field_definitions[name] = (parameter.annotation, ...) else: - d[k] = v - return d + field_definitions[name] = (parameter.annotation, parameter.default) + if field_definitions: + body_model = create_model("temporary", **field_definitions) + setattr(func, "__body_model__", body_model) + + return func diff --git a/rpcpy/utils/openapi.py b/rpcpy/utils/openapi.py new file mode 100644 index 0000000..95e6bda --- /dev/null +++ b/rpcpy/utils/openapi.py @@ -0,0 +1,97 @@ +from copy import deepcopy +from typing import Any, Dict, Optional, Union, Sequence + +try: + from pydantic import create_model + from pydantic import BaseModel +except ImportError: + + def create_model(*args, **kwargs): # type: ignore + raise NotImplementedError() + + BaseModel = None # type: ignore + + +def replace_definitions(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + replace $ref + """ + schema = deepcopy(schema) + + if schema.get("definitions") is not None: + + def replace(value: Union[str, Sequence[Any], Dict[str, Any]]) -> None: + if isinstance(value, str): + return + elif isinstance(value, Sequence): + for _value in value: + replace(_value) + elif isinstance(value, Dict): + for _name in tuple(value.keys()): + if _name == "$ref": + define_schema = schema + for key in value["$ref"][2:].split("/"): + define_schema = define_schema[key] + # replace ref and del it + value.update(define_schema) + del value["$ref"] + else: + replace(value[_name]) + + replace(schema["definitions"]) + replace(schema["properties"]) + del schema["definitions"] + + return schema + + +def schema_request_body(body: BaseModel = None) -> Optional[Dict[str, Any]]: + if body is None: + return None + + _schema = replace_definitions(body.schema()) + del _schema["title"] + + return { + "required": True, + "content": {"application/json": {"schema": _schema}}, + } + + +def schema_response(model: BaseModel) -> Dict[str, Any]: + return {"application/json": {"schema": replace_definitions(model.schema())}} + + +TEMPLATE = """ + + + + OpenAPI power by rpc.py + + + + + + + + + + + + +""" diff --git a/rpcpy/version.py b/rpcpy/version.py index 3985165..cfb0b13 100644 --- a/rpcpy/version.py +++ b/rpcpy/version.py @@ -1,3 +1,3 @@ -VERSION = (0, 2, 2) +VERSION = (0, 3, 0) __version__ = ".".join(map(str, VERSION)) diff --git a/tests/test_application.py b/tests/test_application.py index ea16a27..eef5721 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,3 +1,4 @@ +import httpx import pytest from rpcpy.application import RPC, WSGIRPC, ASGIRPC @@ -19,8 +20,12 @@ def sayhi(name: str) -> str: async def async_sayhi(name: str) -> str: return f"hi {name}" + with httpx.Client(app=rpc, base_url="http://testServer/") as client: + assert client.get("/openapi-docs").status_code == 404 -def test_asgirpc(): + +@pytest.mark.asyncio +async def test_asgirpc(): rpc = RPC(mode="ASGI") assert isinstance(rpc, ASGIRPC) @@ -35,3 +40,84 @@ async def sayhi(name: str) -> str: @rpc.register def sync_sayhi(name: str) -> str: return f"hi {name}" + + async with httpx.AsyncClient(app=rpc, base_url="http://testServer/") as client: + assert (await client.get("/openapi-docs")).status_code == 404 + + +def test_wsgi_openapi(): + rpc = RPC(openapi={"title": "Title", "description": "Description", "version": "v1"}) + + @rpc.register + def sayhi(name: str) -> str: + return f"hi {name}" + + assert rpc.get_openapi_docs() == { + "openapi": "3.0.0", + "info": {"description": "Description", "title": "Title", "version": "v1"}, + "paths": { + "/sayhi": { + "post": { + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"} + }, + "required": ["name"], + } + } + }, + } + } + } + }, + } + + with httpx.Client(app=rpc, base_url="http://testServer/") as client: + assert client.get("/openapi-docs").status_code == 200 + assert client.get("/get-openapi-docs").status_code == 200 + + +@pytest.mark.asyncio +async def test_asgi_openapi(): + rpc = RPC( + mode="ASGI", + openapi={"title": "Title", "description": "Description", "version": "v1"}, + ) + + @rpc.register + async def sayhi(name: str) -> str: + return f"hi {name}" + + assert rpc.get_openapi_docs() == { + "openapi": "3.0.0", + "info": {"description": "Description", "title": "Title", "version": "v1"}, + "paths": { + "/sayhi": { + "post": { + "requestBody": { + "required": True, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "name": {"title": "Name", "type": "string"} + }, + "required": ["name"], + } + } + }, + } + } + } + }, + } + + async with httpx.AsyncClient(app=rpc, base_url="http://testServer/") as client: + assert (await client.get("/openapi-docs")).status_code == 200 + assert (await client.get("/get-openapi-docs")).status_code == 200