diff --git a/README.md b/README.md
index cc0bf55..02e2e76 100644
--- a/README.md
+++ b/README.md
@@ -22,6 +22,9 @@ pip install git+https://github.com/abersheeran/rpc.py@setup.py
### Server side:
+
+Use ASGI
mode to register async def
...
+
```python
import uvicorn
from rpcpy import RPC
@@ -48,9 +51,13 @@ async def yield_data(max_num: int):
if __name__ == "__main__":
uvicorn.run(app, interface="asgi3", port=65432)
```
+
OR
+
+Use WSGI
mode to register def
...
+
```python
import uvicorn
from rpcpy import RPC
@@ -77,9 +84,15 @@ def yield_data(max_num: int):
if __name__ == "__main__":
uvicorn.run(app, interface="wsgi", port=65432)
```
+
### Client side:
+Notice: Regardless of whether the server uses the WSGI mode or the ASGI mode, the client can freely use the asynchronous or synchronous mode.
+
+
+Use httpx.Client()
mode to register def
...
+
```python
import httpx
from rpcpy.client import Client
@@ -101,9 +114,13 @@ def sayhi(name: str) -> str:
def yield_data(max_num: int):
yield
```
+
OR
+
+Use httpx.AsyncClient()
mode to register async def
...
+
```python
import httpx
from rpcpy.client import Client
@@ -125,6 +142,7 @@ async def sayhi(name: str) -> str:
async def yield_data(max_num: int):
yield
```
+
### Sub-route
@@ -142,6 +160,16 @@ RPC(serializer=JSONSerializer())
RPC(serializer=PickleSerializer())
```
+## Type hint and OpenAPI Doc
+
+Thanks to the great work of [pydantic](https://pydantic-docs.helpmanual.io/), which makes rpc.py allow you to use type annotation to annotate the types of function parameters and response values, and perform type verification and JSON serialization . At the same time, it is allowed to generate openapi documents for human reading.
+
+### OpenAPI Documents
+
+If you want to open the OpenAPI document, you need to initialize `RPC` like this `RPC(openapi={"title": "TITLE", "description": "DESCRIPTION", "version": "v1"})`.
+
+Then, visit the `"{prefix}openapi-docs"` of RPC and you will be able to see the automatically generated OpenAPI documentation. (If you do not set the `prefix`, the `prefix` is `"/"`)
+
## Limitations
Currently, function parameters must be serializable by `json`.
diff --git a/poetry.lock b/poetry.lock
index 47d467f..32fcae8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -94,6 +94,15 @@ version = "2.4"
[package.dependencies]
immutables = ">=0.9"
+[[package]]
+category = "main"
+description = "A backport of the dataclasses module for Python 3.6"
+marker = "python_version < \"3.7\""
+name = "dataclasses"
+optional = true
+python-versions = "*"
+version = "0.6"
+
[[package]]
category = "dev"
description = "the modular source code checker: pep8 pyflakes and co"
@@ -310,6 +319,24 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
version = "2.6.0"
+[[package]]
+category = "main"
+description = "Data validation and settings management using python 3.6 type hinting"
+name = "pydantic"
+optional = true
+python-versions = ">=3.6"
+version = "1.6.1"
+
+[package.dependencies]
+[package.dependencies.dataclasses]
+python = "<3.7"
+version = ">=0.6"
+
+[package.extras]
+dotenv = ["python-dotenv (>=0.10.4)"]
+email = ["email-validator (>=1.0.3)"]
+typing_extensions = ["typing-extensions (>=3.7.2)"]
+
[[package]]
category = "dev"
description = "passive checker of Python programs"
@@ -423,7 +450,7 @@ python-versions = "*"
version = "1.4.1"
[[package]]
-category = "dev"
+category = "main"
description = "Backported and Experimental Type Hints for Python 3.5+"
name = "typing-extensions"
optional = false
@@ -453,10 +480,11 @@ testing = ["jaraco.itertools", "func-timeout"]
[extras]
client = ["httpx"]
-full = ["httpx"]
+full = ["httpx", "pydantic"]
+type = ["pydantic"]
[metadata]
-content-hash = "19da082139ff7259d57402210c6735934f87b62178832eec7552220cc18f3b79"
+content-hash = "3361e14f9f0aec6e206ed0395dcb93a164d70119be1c542665928a536bfbceee"
lock-version = "1.0"
python-versions = "^3.6"
@@ -496,6 +524,10 @@ colorama = [
contextvars = [
{file = "contextvars-2.4.tar.gz", hash = "sha256:f38c908aaa59c14335eeea12abea5f443646216c4e29380d7bf34d2018e2c39e"},
]
+dataclasses = [
+ {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
+ {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"},
+]
flake8 = [
{file = "flake8-3.8.3-py2.py3-none-any.whl", hash = "sha256:15e351d19611c887e482fb960eae4d44845013cc142d42896e9862f775d8cf5c"},
{file = "flake8-3.8.3.tar.gz", hash = "sha256:f04b9fcbac03b0a3e58c0ab3a0ecc462e023a9faf046d57794184028123aa208"},
@@ -598,6 +630,25 @@ pycodestyle = [
{file = "pycodestyle-2.6.0-py2.py3-none-any.whl", hash = "sha256:2295e7b2f6b5bd100585ebcb1f616591b652db8a741695b3d8f5d28bdc934367"},
{file = "pycodestyle-2.6.0.tar.gz", hash = "sha256:c58a7d2815e0e8d7972bf1803331fb0152f867bd89adf8a01dfd55085434192e"},
]
+pydantic = [
+ {file = "pydantic-1.6.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:418b84654b60e44c0cdd5384294b0e4bc1ebf42d6e873819424f3b78b8690614"},
+ {file = "pydantic-1.6.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:4900b8820b687c9a3ed753684337979574df20e6ebe4227381d04b3c3c628f99"},
+ {file = "pydantic-1.6.1-cp36-cp36m-manylinux2014_i686.whl", hash = "sha256:b49c86aecde15cde33835d5d6360e55f5e0067bb7143a8303bf03b872935c75b"},
+ {file = "pydantic-1.6.1-cp36-cp36m-manylinux2014_x86_64.whl", hash = "sha256:2de562a456c4ecdc80cf1a8c3e70c666625f7d02d89a6174ecf63754c734592e"},
+ {file = "pydantic-1.6.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f769141ab0abfadf3305d4fcf36660e5cf568a666dd3efab7c3d4782f70946b1"},
+ {file = "pydantic-1.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2dc946b07cf24bee4737ced0ae77e2ea6bc97489ba5a035b603bd1b40ad81f7e"},
+ {file = "pydantic-1.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:36dbf6f1be212ab37b5fda07667461a9219c956181aa5570a00edfb0acdfe4a1"},
+ {file = "pydantic-1.6.1-cp37-cp37m-manylinux2014_i686.whl", hash = "sha256:1783c1d927f9e1366e0e0609ae324039b2479a1a282a98ed6a6836c9ed02002c"},
+ {file = "pydantic-1.6.1-cp37-cp37m-manylinux2014_x86_64.whl", hash = "sha256:cf3933c98cb5e808b62fae509f74f209730b180b1e3c3954ee3f7949e083a7df"},
+ {file = "pydantic-1.6.1-cp37-cp37m-win_amd64.whl", hash = "sha256:f8af9b840a9074e08c0e6dc93101de84ba95df89b267bf7151d74c553d66833b"},
+ {file = "pydantic-1.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:40d765fa2d31d5be8e29c1794657ad46f5ee583a565c83cea56630d3ae5878b9"},
+ {file = "pydantic-1.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:3fa799f3cfff3e5f536cbd389368fc96a44bb30308f258c94ee76b73bd60531d"},
+ {file = "pydantic-1.6.1-cp38-cp38-manylinux2014_i686.whl", hash = "sha256:6c3f162ba175678218629f446a947e3356415b6b09122dcb364e58c442c645a7"},
+ {file = "pydantic-1.6.1-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:eb75dc1809875d5738df14b6566ccf9fd9c0bcde4f36b72870f318f16b9f5c20"},
+ {file = "pydantic-1.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:530d7222a2786a97bc59ee0e0ebbe23728f82974b1f1ad9a11cd966143410633"},
+ {file = "pydantic-1.6.1-py36.py37.py38-none-any.whl", hash = "sha256:b5b3489cb303d0f41ad4a7390cf606a5f2c7a94dcba20c051cd1c653694cb14d"},
+ {file = "pydantic-1.6.1.tar.gz", hash = "sha256:54122a8ed6b75fe1dd80797f8251ad2063ea348a03b77218d73ea9fe19bd4e73"},
+]
pyflakes = [
{file = "pyflakes-2.2.0-py2.py3-none-any.whl", hash = "sha256:0d94e0e05a19e57a99444b6ddcf9a6eb2e5c68d3ca1e98e90707af8152c90a92"},
{file = "pyflakes-2.2.0.tar.gz", hash = "sha256:35b2d75ee967ea93b55750aa9edbbf72813e06a66ba54438df2cfac9e3c27fc8"},
diff --git a/pyproject.toml b/pyproject.toml
index ad346a1..74890ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rpc.py"
-version = "0.2.2"
+version = "0.3.0"
description = "An easy-to-use and powerful RPC framework. Base WSGI & ASGI."
authors = ["abersheeran "]
readme = "README.md"
@@ -20,11 +20,14 @@ packages = [
[tool.poetry.dependencies]
python = "^3.6"
+typing-extensions = {version = "^3.7.4", python = "<3.8"}
httpx = {version = "^0.13.3", optional = true} # for client and test
+pydantic = {version = "^1.6.1", optional = true} # for openapi docs
[tool.poetry.extras]
client = ["httpx"]
-full = ["httpx"]
+type = ["pydantic"]
+full = ["httpx", "pydantic"]
[tool.poetry.dev-dependencies]
flake8 = "*"
diff --git a/rpcpy/application.py b/rpcpy/application.py
index 8eaaaac..fdea4d4 100644
--- a/rpcpy/application.py
+++ b/rpcpy/application.py
@@ -1,9 +1,11 @@
import typing
import inspect
+import copy
+import json
from base64 import b64encode
from types import FunctionType
-from rpcpy.types import Environ, StartResponse, Scope, Receive, Send
+from rpcpy.types import Environ, StartResponse, Scope, Receive, Send, TypedDict
from rpcpy.serializers import BaseSerializer, JSONSerializer
from rpcpy.asgi import (
Request as ASGIRequest,
@@ -15,10 +17,21 @@
Response as WSGIResponse,
EventResponse as WSGIEventResponse,
)
+from rpcpy.utils import set_type_model
+from rpcpy.utils.openapi import (
+ BaseModel,
+ schema_request_body,
+ schema_response,
+ TEMPLATE as OpenapiTemplate,
+)
__all__ = ["RPC"]
Function = typing.TypeVar("Function", FunctionType, FunctionType)
+MethodNotAllowedHttpCode = {
+ "GET": 404,
+ "HEAD": 404,
+} # https://developer.mozilla.org/zh-CN/docs/Web/HTTP/Status/405
class RPCMeta(type):
@@ -36,6 +49,9 @@ def __call__(cls, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return super().__call__(*args, **kwargs)
+OpenAPI = TypedDict("OpenAPI", {"title": str, "description": str, "version": str})
+
+
class RPC(metaclass=RPCMeta):
def __init__(
self,
@@ -43,24 +59,64 @@ def __init__(
prefix: str = "/",
mode: str = "WSGI",
serializer: BaseSerializer = JSONSerializer(),
+ openapi: OpenAPI = None,
):
assert mode in ("WSGI", "ASGI"), "mode must be in ('WSGI', 'ASGI')"
assert prefix.startswith("/") and prefix.endswith("/")
self.callbacks: typing.Dict[str, typing.Callable] = {}
self.prefix = prefix
self.serializer = serializer
+ self.openapi = openapi
def register(self, func: Function) -> Function:
self.callbacks[func.__name__] = func
+ set_type_model(func)
return func
+ def get_openapi_docs(self) -> dict:
+ openapi: dict = {
+ "openapi": "3.0.0",
+ "info": copy.deepcopy(self.openapi) or {},
+ "paths": {},
+ }
+
+ for name, callback in self.callbacks.items():
+ _ = {}
+ # summary and description
+ doc = callback.__doc__
+ if isinstance(doc, str):
+ doc = doc.strip()
+ _.update(
+ {
+ "summary": doc.splitlines()[0],
+ "description": "\n".join(doc.splitlines()[1:]).strip(),
+ }
+ )
+ # request body
+ body_doc = schema_request_body(getattr(callback, "__body_model__", None))
+ if body_doc:
+ _["requestBody"] = body_doc
+ # response & only 200
+ sig = inspect.signature(callback)
+ if (
+ sig.return_annotation != sig.empty
+ and inspect.isclass(sig.return_annotation)
+ and issubclass(sig.return_annotation, BaseModel)
+ ):
+ _["responses"] = {
+ 200: {"content": schema_response(sig.return_annotation)}
+ }
+ if _:
+ openapi["paths"][f"{self.prefix}{name}"] = {"post": _}
+
+ return openapi
+
class WSGIRPC(RPC):
def register(self, func: Function) -> Function:
if inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
raise TypeError("WSGI mode can only register synchronization functions.")
- self.callbacks[func.__name__] = func
- return func
+ return super().register(func)
def create_generator(
self, generator: typing.Generator
@@ -72,14 +128,31 @@ def __call__(
self, environ: Environ, start_response: StartResponse
) -> typing.Iterable[bytes]:
request = WSGIRequest(environ)
+ if self.openapi is not None and request.method == "GET":
+ if request.url.path[len(self.prefix) :] == "openapi-docs":
+ return WSGIResponse(
+ OpenapiTemplate, headers={"content-type": "text/html"},
+ )(environ, start_response)
+ elif request.url.path[len(self.prefix) :] == "get-openapi-docs":
+ return WSGIResponse(
+ json.dumps(self.get_openapi_docs(), ensure_ascii=False),
+ headers={"content-type": "application/json"},
+ )(environ, start_response)
+
if request.method != "POST":
- return WSGIResponse(status_code=405)(environ, start_response)
+ return WSGIResponse(
+ status_code=MethodNotAllowedHttpCode.get(request.method, 405)
+ )(environ, start_response)
content_type = request.headers["content-type"]
assert content_type == "application/json"
data = request.json
- result = self.callbacks[request.url.path[len(self.prefix) :]](**data)
+ callback = self.callbacks[request.url.path[len(self.prefix) :]]
+ if hasattr(callback, "__body_model__"):
+ result = callback(**getattr(callback, "__body_model__")(**data).dict())
+ else:
+ result = callback(**data)
if inspect.isgenerator(result):
return WSGIEventResponse(
@@ -89,7 +162,10 @@ def __call__(
return WSGIResponse(
self.serializer.encode(result),
- headers={"serializer": self.serializer.name},
+ headers={
+ "serializer": self.serializer.name,
+ "content-type": self.serializer.content_type,
+ },
)(environ, start_response)
@@ -97,8 +173,7 @@ class ASGIRPC(RPC):
def register(self, func: Function) -> Function:
if not (inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)):
raise TypeError("ASGI mode can only register asynchronous functions.")
- self.callbacks[func.__name__] = func
- return func
+ return super().register(func)
async def create_generator(
self, generator: typing.AsyncGenerator
@@ -108,15 +183,31 @@ async def create_generator(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = ASGIRequest(scope, receive, send)
+ if self.openapi is not None and request.method == "GET":
+ if request.url.path[len(self.prefix) :] == "openapi-docs":
+ return await ASGIResponse(
+ OpenapiTemplate, headers={"content-type": "text/html"},
+ )(scope, receive, send)
+ elif request.url.path[len(self.prefix) :] == "get-openapi-docs":
+ return await ASGIResponse(
+ json.dumps(self.get_openapi_docs(), ensure_ascii=False),
+ headers={"content-type": "application/json"},
+ )(scope, receive, send)
if request.method != "POST":
- return await ASGIResponse(status_code=405)(scope, receive, send)
+ return await ASGIResponse(
+ status_code=MethodNotAllowedHttpCode.get(request.method, 405)
+ )(scope, receive, send)
content_type = request.headers["content-type"]
assert content_type == "application/json"
data = await request.json()
- result = self.callbacks[request.url.path[len(self.prefix) :]](**data)
+ callback = self.callbacks[request.url.path[len(self.prefix) :]]
+ if hasattr(callback, "__body_model__"):
+ result = callback(**getattr(callback, "__body_model__")(**data).dict())
+ else:
+ result = callback(**data)
if inspect.isasyncgen(result):
return await ASGIEventResponse(
@@ -126,5 +217,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await ASGIResponse(
self.serializer.encode(await result),
- headers={"serializer": self.serializer.name},
+ headers={
+ "serializer": self.serializer.name,
+ "content-type": self.serializer.content_type,
+ },
)(scope, receive, send)
diff --git a/rpcpy/client.py b/rpcpy/client.py
index 95a9fe4..11283ce 100644
--- a/rpcpy/client.py
+++ b/rpcpy/client.py
@@ -1,12 +1,14 @@
import typing
import inspect
import functools
+import json
from base64 import b64decode
from types import FunctionType
import httpx
from rpcpy.serializers import get_serializer
+from rpcpy.utils import set_type_model
__all__ = ["Client"]
@@ -24,12 +26,12 @@ def __init__(
def remote_call(self, func: Function) -> Function:
is_async = inspect.iscoroutinefunction(func) or inspect.isasyncgenfunction(func)
-
+ set_type_model(func) # try set `__body_model__`
if is_async:
- return self.async_remote_call(func)
- return self.sync_remote_call(func)
+ return self.__async_remote_call(func)
+ return self.__sync_remote_call(func)
- def async_remote_call(self, func: Function) -> Function:
+ def __async_remote_call(self, func: Function) -> Function:
if not self.is_async:
raise TypeError(
"Synchronization Client can only register synchronization functions."
@@ -41,9 +43,19 @@ def async_remote_call(self, func: Function) -> Function:
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
sig = inspect.signature(func)
bound_values = sig.bind(*args, **kwargs)
+ if hasattr(func, "__body_model__"):
+ post_json = (
+ getattr(func, "__body_model__")(**bound_values.arguments)
+ .json()
+ .encode("utf8")
+ )
+ else:
+ post_json = json.dumps(dict(**bound_values.arguments)).encode(
+ "utf8"
+ )
url = self.base_url + func.__name__
resp: httpx.Response = await self.client.post( # type: ignore
- url, json=dict(bound_values.arguments.items())
+ url, data=post_json, headers={"content-type": "application/json"}
)
resp.raise_for_status()
serializer = get_serializer(resp.headers)
@@ -55,9 +67,22 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
sig = inspect.signature(func)
bound_values = sig.bind(*args, **kwargs)
+ if hasattr(func, "__body_model__"):
+ post_json = (
+ getattr(func, "__body_model__")(**bound_values.arguments)
+ .json()
+ .encode("utf8")
+ )
+ else:
+ post_json = json.dumps(dict(**bound_values.arguments)).encode(
+ "utf8"
+ )
url = self.base_url + func.__name__
async with self.client.stream(
- "POST", url, json=dict(bound_values.arguments.items())
+ "POST",
+ url,
+ data=post_json,
+ headers={"content-type": "application/json"},
) as resp: # type: httpx.Response
serializer = get_serializer(resp.headers)
# I don't know how to solve this error:
@@ -70,7 +95,7 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return typing.cast(Function, wrapper)
- def sync_remote_call(self, func: Function) -> Function:
+ def __sync_remote_call(self, func: Function) -> Function:
if self.is_async:
raise TypeError(
"Asynchronous Client can only register asynchronous functions."
@@ -81,9 +106,19 @@ def sync_remote_call(self, func: Function) -> Function:
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
sig = inspect.signature(func)
bound_values = sig.bind(*args, **kwargs)
+ if hasattr(func, "__body_model__"):
+ post_json = (
+ getattr(func, "__body_model__")(**bound_values.arguments)
+ .json()
+ .encode("utf8")
+ )
+ else:
+ post_json = json.dumps(dict(**bound_values.arguments)).encode(
+ "utf8"
+ )
url = self.base_url + func.__name__
resp: httpx.Response = self.client.post( # type: ignore
- url, json=dict(bound_values.arguments.items()),
+ url, data=post_json, headers={"content-type": "application/json"}
)
resp.raise_for_status()
serializer = get_serializer(resp.headers)
@@ -95,9 +130,22 @@ def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
sig = inspect.signature(func)
bound_values = sig.bind(*args, **kwargs)
+ if hasattr(func, "__body_model__"):
+ post_json = (
+ getattr(func, "__body_model__")(**bound_values.arguments)
+ .json()
+ .encode("utf8")
+ )
+ else:
+ post_json = json.dumps(dict(**bound_values.arguments)).encode(
+ "utf8"
+ )
url = self.base_url + func.__name__
with self.client.stream(
- "POST", url, json=dict(bound_values.arguments.items())
+ "POST",
+ url,
+ data=post_json,
+ headers={"content-type": "application/json"},
) as resp: # type: httpx.Response
serializer = get_serializer(resp.headers)
for line in resp.iter_lines():
diff --git a/rpcpy/serializers.py b/rpcpy/serializers.py
index f81334c..a64cdde 100644
--- a/rpcpy/serializers.py
+++ b/rpcpy/serializers.py
@@ -12,6 +12,7 @@ class BaseSerializer(metaclass=ABCMeta):
"""
name: str
+ content_type: str
@abstractmethod
def encode(self, data: typing.Any) -> bytes:
@@ -28,6 +29,7 @@ def json_default(obj: typing.Any) -> typing.Any:
class JSONSerializer(BaseSerializer):
name = "json"
+ content_type = "application/json"
def __init__(self, default: typing.Callable = json_default) -> None:
self.default = default
@@ -41,6 +43,7 @@ def decode(self, data: bytes) -> typing.Any:
class PickleSerializer(BaseSerializer):
name = "pickle"
+ content_type = "application/x-pickle"
def encode(self, data: typing.Any) -> bytes:
return pickle.dumps(data)
diff --git a/rpcpy/types.py b/rpcpy/types.py
index 9a9e6f7..86c56a2 100644
--- a/rpcpy/types.py
+++ b/rpcpy/types.py
@@ -1,3 +1,4 @@
+import sys
from types import TracebackType
from typing import (
Any,
@@ -10,6 +11,11 @@
Awaitable,
)
+if sys.version_info[:2] < (3, 8):
+ from typing_extensions import TypedDict, Literal, Final, final
+else:
+ from typing import TypedDict, Literal, Final, final
+
__all__ = [
"Scope",
"Message",
@@ -20,6 +26,12 @@
"Environ",
"StartResponse",
"WSGIApp",
+] + [
+ # built-in types
+ "TypedDict",
+ "Literal",
+ "Final",
+ "final",
]
# ASGI
diff --git a/rpcpy/utils/__init__.py b/rpcpy/utils/__init__.py
index ec20a8a..e5de010 100644
--- a/rpcpy/utils/__init__.py
+++ b/rpcpy/utils/__init__.py
@@ -1,6 +1,10 @@
import typing
+import functools
+import inspect
from http import cookies as http_cookies
+from .openapi import create_model
+
def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
"""
@@ -25,37 +29,56 @@ def cookie_parser(cookie_string: str) -> typing.Dict[str, str]:
return cookie_dict
-class cached_property:
- """
- A property that is only computed once per instance and then replaces
- itself with an ordinary attribute. Deleting the attribute resets the
- property.
- """
+if typing.TYPE_CHECKING:
+ # https://github.com/python/mypy/issues/5107
+ # for mypy check and IDE support
+ cached_property = property
+else:
+
+ class cached_property:
+ """
+ A property that is only computed once per instance and then replaces
+ itself with an ordinary attribute. Deleting the attribute resets the
+ property.
+ """
- def __init__(self, func: typing.Callable) -> None:
- self.__doc__ = getattr(func, "__doc__")
- self.func = func
+ def __init__(self, func: typing.Callable) -> None:
+ self.func = func
+ functools.update_wrapper(self, func)
- def __get__(self, obj: typing.Any, cls: typing.Any) -> typing.Any:
- if obj is None:
- return self
- value = obj.__dict__[self.func.__name__] = self.func(obj)
- return value
+ def __get__(self, obj: typing.Any, cls: typing.Any) -> typing.Any:
+ if obj is None:
+ return self
+ value = obj.__dict__[self.func.__name__] = self.func(obj)
+ return value
-def merge_list(
- raw: typing.List[typing.Tuple[str, str]]
-) -> typing.Dict[str, typing.Union[typing.List[str], str]]:
+Function = typing.TypeVar("Function", typing.Callable, typing.Callable)
+
+
+def set_type_model(func: Function) -> Function:
"""
- If there are values with the same key value, they are merged into a List.
+ try generate request body model from type hint and default value
"""
- d: typing.Dict[str, typing.Union[typing.List[str], str]] = {}
- for k, v in raw:
- if k in d:
- if isinstance(d[k], list):
- typing.cast(typing.List, d[k]).append(v)
- else:
- d[k] = [typing.cast(str, d[k]), v]
+ sig = inspect.signature(func)
+ field_definitions = {}
+ for name, parameter in sig.parameters.items():
+ if (
+ parameter.annotation == parameter.empty
+ and parameter.default == parameter.empty
+ ):
+ # raise ValueError(
+ # f"You must specify the type for the parameter {func.__name__}:{name}."
+ # )
+ return func # Maybe the type hint should be mandatory? I'm not sure.
+ if parameter.annotation == parameter.empty:
+ field_definitions[name] = parameter.default
+ elif parameter.default == parameter.empty:
+ field_definitions[name] = (parameter.annotation, ...)
else:
- d[k] = v
- return d
+ field_definitions[name] = (parameter.annotation, parameter.default)
+ if field_definitions:
+ body_model = create_model("temporary", **field_definitions)
+ setattr(func, "__body_model__", body_model)
+
+ return func
diff --git a/rpcpy/utils/openapi.py b/rpcpy/utils/openapi.py
new file mode 100644
index 0000000..95e6bda
--- /dev/null
+++ b/rpcpy/utils/openapi.py
@@ -0,0 +1,97 @@
+from copy import deepcopy
+from typing import Any, Dict, Optional, Union, Sequence
+
+try:
+ from pydantic import create_model
+ from pydantic import BaseModel
+except ImportError:
+
+ def create_model(*args, **kwargs): # type: ignore
+ raise NotImplementedError()
+
+ BaseModel = None # type: ignore
+
+
+def replace_definitions(schema: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ replace $ref
+ """
+ schema = deepcopy(schema)
+
+ if schema.get("definitions") is not None:
+
+ def replace(value: Union[str, Sequence[Any], Dict[str, Any]]) -> None:
+ if isinstance(value, str):
+ return
+ elif isinstance(value, Sequence):
+ for _value in value:
+ replace(_value)
+ elif isinstance(value, Dict):
+ for _name in tuple(value.keys()):
+ if _name == "$ref":
+ define_schema = schema
+ for key in value["$ref"][2:].split("/"):
+ define_schema = define_schema[key]
+ # replace ref and del it
+ value.update(define_schema)
+ del value["$ref"]
+ else:
+ replace(value[_name])
+
+ replace(schema["definitions"])
+ replace(schema["properties"])
+ del schema["definitions"]
+
+ return schema
+
+
+def schema_request_body(body: BaseModel = None) -> Optional[Dict[str, Any]]:
+ if body is None:
+ return None
+
+ _schema = replace_definitions(body.schema())
+ del _schema["title"]
+
+ return {
+ "required": True,
+ "content": {"application/json": {"schema": _schema}},
+ }
+
+
+def schema_response(model: BaseModel) -> Dict[str, Any]:
+ return {"application/json": {"schema": replace_definitions(model.schema())}}
+
+
+TEMPLATE = """
+
+
+
+ OpenAPI power by rpc.py
+
+
+
+
+
+
+
+
+
+
+
+
+"""
diff --git a/rpcpy/version.py b/rpcpy/version.py
index 3985165..cfb0b13 100644
--- a/rpcpy/version.py
+++ b/rpcpy/version.py
@@ -1,3 +1,3 @@
-VERSION = (0, 2, 2)
+VERSION = (0, 3, 0)
__version__ = ".".join(map(str, VERSION))
diff --git a/tests/test_application.py b/tests/test_application.py
index ea16a27..eef5721 100644
--- a/tests/test_application.py
+++ b/tests/test_application.py
@@ -1,3 +1,4 @@
+import httpx
import pytest
from rpcpy.application import RPC, WSGIRPC, ASGIRPC
@@ -19,8 +20,12 @@ def sayhi(name: str) -> str:
async def async_sayhi(name: str) -> str:
return f"hi {name}"
+ with httpx.Client(app=rpc, base_url="http://testServer/") as client:
+ assert client.get("/openapi-docs").status_code == 404
-def test_asgirpc():
+
+@pytest.mark.asyncio
+async def test_asgirpc():
rpc = RPC(mode="ASGI")
assert isinstance(rpc, ASGIRPC)
@@ -35,3 +40,84 @@ async def sayhi(name: str) -> str:
@rpc.register
def sync_sayhi(name: str) -> str:
return f"hi {name}"
+
+ async with httpx.AsyncClient(app=rpc, base_url="http://testServer/") as client:
+ assert (await client.get("/openapi-docs")).status_code == 404
+
+
+def test_wsgi_openapi():
+ rpc = RPC(openapi={"title": "Title", "description": "Description", "version": "v1"})
+
+ @rpc.register
+ def sayhi(name: str) -> str:
+ return f"hi {name}"
+
+ assert rpc.get_openapi_docs() == {
+ "openapi": "3.0.0",
+ "info": {"description": "Description", "title": "Title", "version": "v1"},
+ "paths": {
+ "/sayhi": {
+ "post": {
+ "requestBody": {
+ "required": True,
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "properties": {
+ "name": {"title": "Name", "type": "string"}
+ },
+ "required": ["name"],
+ }
+ }
+ },
+ }
+ }
+ }
+ },
+ }
+
+ with httpx.Client(app=rpc, base_url="http://testServer/") as client:
+ assert client.get("/openapi-docs").status_code == 200
+ assert client.get("/get-openapi-docs").status_code == 200
+
+
+@pytest.mark.asyncio
+async def test_asgi_openapi():
+ rpc = RPC(
+ mode="ASGI",
+ openapi={"title": "Title", "description": "Description", "version": "v1"},
+ )
+
+ @rpc.register
+ async def sayhi(name: str) -> str:
+ return f"hi {name}"
+
+ assert rpc.get_openapi_docs() == {
+ "openapi": "3.0.0",
+ "info": {"description": "Description", "title": "Title", "version": "v1"},
+ "paths": {
+ "/sayhi": {
+ "post": {
+ "requestBody": {
+ "required": True,
+ "content": {
+ "application/json": {
+ "schema": {
+ "type": "object",
+ "properties": {
+ "name": {"title": "Name", "type": "string"}
+ },
+ "required": ["name"],
+ }
+ }
+ },
+ }
+ }
+ }
+ },
+ }
+
+ async with httpx.AsyncClient(app=rpc, base_url="http://testServer/") as client:
+ assert (await client.get("/openapi-docs")).status_code == 200
+ assert (await client.get("/get-openapi-docs")).status_code == 200