diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..7ef80e0
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,6 @@
+[run]
+omit =
+ zero/zeromq_patterns/factory.py
+ zero/zeromq_patterns/helpers.py
+ zero/logger.py
+ zero/rpc/protocols.py
\ No newline at end of file
diff --git a/Dockerfile.test.py310 b/Dockerfile.test.py310
deleted file mode 100644
index fc33837..0000000
--- a/Dockerfile.test.py310
+++ /dev/null
@@ -1,9 +0,0 @@
-FROM python:3.10-slim
-
-COPY tests/requirements.txt .
-RUN pip install -r requirements.txt
-
-COPY zero ./zero
-COPY tests ./tests
-
-CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"]
diff --git a/Dockerfile.test.py38 b/Dockerfile.test.py38
deleted file mode 100644
index 209a32e..0000000
--- a/Dockerfile.test.py38
+++ /dev/null
@@ -1,9 +0,0 @@
-FROM python:3.8-slim
-
-COPY tests/requirements.txt .
-RUN pip install -r requirements.txt
-
-COPY zero ./zero
-COPY tests ./tests
-
-CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"]
diff --git a/Dockerfile.test.py39 b/Dockerfile.test.py39
deleted file mode 100644
index 5e6c2fd..0000000
--- a/Dockerfile.test.py39
+++ /dev/null
@@ -1,9 +0,0 @@
-FROM python:3.9-slim
-
-COPY tests/requirements.txt .
-RUN pip install -r requirements.txt
-
-COPY zero ./zero
-COPY tests ./tests
-
-CMD ["pytest", "tests", "--cov=zero", "--cov-report=term-missing", "-vv"]
diff --git a/Makefile b/Makefile
index 413330e..ca8d107 100644
--- a/Makefile
+++ b/Makefile
@@ -9,7 +9,7 @@ setup:
)
test:
- python3 -m pytest tests --cov=zero --cov-report=term-missing -vv --durations=10 --timeout=280
+ python3 -m pytest tests --cov=zero --cov-report=term-missing --cov-config=.coveragerc -vv --durations=10 --timeout=280
docker-test:
docker build -t zero-test -f Dockerfile.test.py38 .
diff --git a/README.md b/README.md
index 1754f17..5535e97 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@
-
+
@@ -27,16 +27,16 @@
**Features**:
-* Zero provides **faster communication** (see [benchmarks](https://github.com/Ananto30/zero#benchmarks-)) between the microservices using [zeromq](https://zeromq.org/) under the hood.
-* Zero uses messages for communication and traditional **client-server** or **request-reply** pattern is supported.
-* Support for both **async** and **sync**.
-* The base server (ZeroServer) **utilizes all cpu cores**.
-* **Code generation**! See [example](https://github.com/Ananto30/zero#code-generation-) 👇
+* Zero provides **faster communication** (see [benchmarks](https://github.com/Ananto30/zero#benchmarks-)) between the microservices using [zeromq](https://zeromq.org/) under the hood.
+* Zero uses messages for communication and traditional **client-server** or **request-reply** pattern is supported.
+* Support for both **async** and **sync**.
+* The base server (ZeroServer) **utilizes all cpu cores**.
+* **Code generation**! See [example](https://github.com/Ananto30/zero#code-generation-) 👇
**Philosophy** behind Zero:
-* **Zero learning curve**: The learning curve is tends to zero. Just add functions and spin up a server, literally that's it! The framework hides the complexity of messaging pattern that enables faster communication.
-* **ZeroMQ**: An awesome messaging library enables the power of Zero.
+* **Zero learning curve**: The learning curve is tends to zero. Just add functions and spin up a server, literally that's it! The framework hides the complexity of messaging pattern that enables faster communication.
+* **ZeroMQ**: An awesome messaging library enables the power of Zero.
Let's get started!
@@ -44,83 +44,86 @@ Let's get started!
*Ensure Python 3.8+*
- pip install zeroapi
+```
+pip install zeroapi
+```
**For Windows**, [tornado](https://pypi.org/project/tornado/) needs to be installed separately (for async operations). It's not included with `zeroapi` because for linux and mac-os, tornado is not needed as they have their own event loops.
-* Create a `server.py`
+* Create a `server.py`
+
+ ```python
+ from zero import ZeroServer
- ```python
- from zero import ZeroServer
+ app = ZeroServer(port=5559)
- app = ZeroServer(port=5559)
+ @app.register_rpc
+ def echo(msg: str) -> str:
+ return msg
- @app.register_rpc
- def echo(msg: str) -> str:
- return msg
+ @app.register_rpc
+ async def hello_world() -> str:
+ return "hello world"
- @app.register_rpc
- async def hello_world() -> str:
- return "hello world"
+ if __name__ == "__main__":
+ app.run()
+ ```
- if __name__ == "__main__":
- app.run()
- ```
+* The **RPC functions only support one argument** (`msg`) for now.
-* The **RPC functions only support one argument** (`msg`) for now.
+* Also note that server **RPC functions are type hinted**. Type hint is **must** in Zero server. Supported types can be found [here](/zero/utils/type_util.py#L11).
-* Also note that server **RPC functions are type hinted**. Type hint is **must** in Zero server. Supported types can be found [here](/zero/utils/type_util.py#L11).
+* Run the server
-* Run the server
- ```shell
- python -m server
- ```
+ ```shell
+ python -m server
+ ```
-* Call the rpc methods
+* Call the rpc methods
- ```python
- from zero import ZeroClient
+ ```python
+ from zero import ZeroClient
- zero_client = ZeroClient("localhost", 5559)
+ zero_client = ZeroClient("localhost", 5559)
- def echo():
- resp = zero_client.call("echo", "Hi there!")
- print(resp)
+ def echo():
+ resp = zero_client.call("echo", "Hi there!")
+ print(resp)
- def hello():
- resp = zero_client.call("hello_world", None)
- print(resp)
+ def hello():
+ resp = zero_client.call("hello_world", None)
+ print(resp)
- if __name__ == "__main__":
- echo()
- hello()
- ```
+ if __name__ == "__main__":
+ echo()
+ hello()
+ ```
-* Or using async client -
+* Or using async client -
- ```python
- import asyncio
+ ```python
+ import asyncio
- from zero import AsyncZeroClient
+ from zero import AsyncZeroClient
- zero_client = AsyncZeroClient("localhost", 5559)
+ zero_client = AsyncZeroClient("localhost", 5559)
- async def echo():
- resp = await zero_client.call("echo", "Hi there!")
- print(resp)
+ async def echo():
+ resp = await zero_client.call("echo", "Hi there!")
+ print(resp)
- async def hello():
- resp = await zero_client.call("hello_world", None)
- print(resp)
+ async def hello():
+ resp = await zero_client.call("hello_world", None)
+ print(resp)
- if __name__ == "__main__":
- loop = asyncio.get_event_loop()
- loop.run_until_complete(echo())
- loop.run_until_complete(hello())
- ```
+ if __name__ == "__main__":
+ loop = asyncio.get_event_loop()
+ loop.run_until_complete(echo())
+ loop.run_until_complete(hello())
+ ```
# Serialization 📦
@@ -224,9 +227,9 @@ Currently, the code generation tool supports only `ZeroClient` and not `AsyncZer
# Important notes! 📝
-* `ZeroServer` should always be run under `if __name__ == "__main__":`, as it uses multiprocessing.
-* `ZeroServer` creates the workers in different processes, so anything global in your code will be instantiated N times where N is the number of workers. So if you want to initiate them once, put them under `if __name__ == "__main__":`. But recommended to not use global vars. And Databases, Redis, other clients, creating them N times in different processes is fine and preferred.
-* The methods which are under `register_rpc()` in `ZeroServer` should have **type hinting**, like `def echo(msg: str) -> str:`
+* `ZeroServer` should always be run under `if __name__ == "__main__":`, as it uses multiprocessing.
+* `ZeroServer` creates the workers in different processes, so anything global in your code will be instantiated N times where N is the number of workers. So if you want to initiate them once, put them under `if __name__ == "__main__":`. But recommended to not use global vars. And Databases, Redis, other clients, creating them N times in different processes is fine and preferred.
+* The methods which are under `register_rpc()` in `ZeroServer` should have **type hinting**, like `def echo(msg: str) -> str:`
# Let's do some benchmarking! 🏎
@@ -236,8 +239,8 @@ So we will be testing a gateway calling another server for some data. Check the
There are two endpoints in every tests,
-* `/hello`: Just call for a hello world response 😅
-* `/order`: Save a Order object in redis
+* `/hello`: Just call for a hello world response 😅
+* `/order`: Save a Order object in redis
Compare the results! 👇
@@ -247,23 +250,23 @@ Compare the results! 👇
*(Sorted alphabetically)*
-Framework | "hello world" (req/s) | 99% latency (ms) | redis save (req/s) | 99% latency (ms)
------------ | --------------------- | ---------------- | ------------------ | ----------------
-aiohttp | 14949.57 | 8.91 | 9753.87 | 13.75
-aiozmq | 13844.67 | 9.55 | 5239.14 | 30.92
-blacksheep | 32967.27 | 3.03 | 18010.67 | 6.79
-fastApi | 13154.96 | 9.07 | 8369.87 | 15.91
-sanic | 18793.08 | 5.88 | 12739.37 | 8.78
-zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69
-zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80
+| Framework | "hello world" (req/s) | 99% latency (ms) | redis save (req/s) | 99% latency (ms) |
+| ----------- | --------------------- | ---------------- | ------------------ | ---------------- |
+| aiohttp | 14949.57 | 8.91 | 9753.87 | 13.75 |
+| aiozmq | 13844.67 | 9.55 | 5239.14 | 30.92 |
+| blacksheep | 32967.27 | 3.03 | 18010.67 | 6.79 |
+| fastApi | 13154.96 | 9.07 | 8369.87 | 15.91 |
+| sanic | 18793.08 | 5.88 | 12739.37 | 8.78 |
+| zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69 |
+| zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80 |
Seems like blacksheep is faster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆
# Roadmap 🗺
-* [x] Make msgspec as default serializer
-* [ ] Add support for async server (currently the sync server runs async functions in the eventloop, which is blocking)
-* [ ] Add pub/sub support
+* \[x] Make msgspec as default serializer
+* \[ ] Add support for async server (currently the sync server runs async functions in the eventloop, which is blocking)
+* \[ ] Add pub/sub support
# Contribution
diff --git a/benchmarks/dockerize/README.md b/benchmarks/dockerize/README.md
index 412128a..55b808c 100644
--- a/benchmarks/dockerize/README.md
+++ b/benchmarks/dockerize/README.md
@@ -52,7 +52,6 @@ I have used 2x cpu threads so `-t 16` and 16x25 = 400 connections.
| sanic | 13195.99 | 20.04 | 7226.72 | 25.24 |
| zero | 18867.00 | 11.48 | 12293.81 | 11.68 |
-
## Old benchmark results
Intel Core i3 10100, 4 cores, 8 threads, 16GB RAM, with docker limits **cpu 40% and memory 256m**
@@ -67,7 +66,6 @@ Intel Core i3 10100, 4 cores, 8 threads, 16GB RAM, with docker limits **cpu 40%
| sanic | 3,085.80 req/s | 547.02 req/s |
| zero | 5,000.77 req/s | 784.51 req/s |
-
MacBook Pro (13-inch, M1, 2020), Apple M1, 8 cores (4 performance and 4 efficiency), 8 GB RAM
*(Sorted alphabetically)*
@@ -81,7 +79,6 @@ MacBook Pro (13-inch, M1, 2020), Apple M1, 8 cores (4 performance and 4 efficien
More about MacBook benchmarks [here](https://github.com/Ananto30/zero/blob/main/benchmarks/others/mac-results.md)
-
-### Note!
+### Note
Please note that sometimes just `docker-compose up` will not run the `wrk`. Because you know about the docker `depends_on` only ensures the service is up, not running or healthy. So you may need to run wrk service after other services are up and running.
diff --git a/examples/basic/schema.py b/examples/basic/schema.py
index ecb8e66..9e4057b 100644
--- a/examples/basic/schema.py
+++ b/examples/basic/schema.py
@@ -1,9 +1,30 @@
+from dataclasses import dataclass
+from datetime import date
from typing import List
import msgspec
+class Address(msgspec.Struct):
+ street: str
+ city: str
+ zip: int
+
+
class User(msgspec.Struct):
name: str
age: int
emails: List[str]
+ addresses: List[Address]
+ registered_at: date
+
+
+@dataclass
+class Teacher:
+ name: str
+
+
+class Student(User):
+ roll_no: int
+ marks: List[int]
+ teachers: List[Teacher]
diff --git a/examples/basic/server.py b/examples/basic/server.py
index ffc3c36..2bac606 100644
--- a/examples/basic/server.py
+++ b/examples/basic/server.py
@@ -5,7 +5,7 @@
from zero import ZeroServer
-from .schema import User
+from .schema import Student, Teacher, User
app = ZeroServer(port=5559)
@@ -42,6 +42,17 @@ def hello_users(users: typing.List[User]) -> str:
return f"Hello {', '.join([user.name for user in users])}! Your emails are {', '.join([email for user in users for email in user.emails])}!"
+teachers = [
+ Teacher(name="Teacher1"),
+ Teacher(name="Teacher2"),
+]
+
+
+@app.register_rpc
+def hello_student(student: Student) -> str:
+ return f"Hello {student.name}! You are {student.age} years old. Your email is {student.emails[0]}! Your roll no. is {student.roll_no} and your marks are {student.marks}!"
+
+
if __name__ == "__main__":
app.register_rpc(echo)
app.register_rpc(hello_world)
diff --git a/tests/functional/codegen/__init__.py b/tests/functional/codegen/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/functional/codegen/test_codegen.py b/tests/functional/codegen/test_codegen.py
new file mode 100644
index 0000000..df43014
--- /dev/null
+++ b/tests/functional/codegen/test_codegen.py
@@ -0,0 +1,708 @@
+import dataclasses
+import datetime
+import decimal
+import enum
+import typing
+import unittest
+import uuid
+from dataclasses import dataclass
+from datetime import date
+from typing import Dict, List, Optional, Tuple, Union
+
+import msgspec
+from msgspec import Struct
+
+from zero.codegen.codegen import CodeGen
+
+
+@dataclass
+class SimpleDataclass:
+ a: int
+ b: str
+
+
+@dataclasses.dataclass
+class SimpleDataclass2:
+ c: int
+ d: str
+
+
+@dataclass
+class ChildDataclass(SimpleDataclass):
+ e: int
+ f: str
+
+
+class SimpleStruct(Struct):
+ h: int
+ i: str
+
+
+class ComplexStruct(msgspec.Struct):
+ a: int
+ b: str
+ c: SimpleStruct
+ d: List[SimpleStruct]
+ e: Dict[str, SimpleStruct]
+ f: Tuple[SimpleDataclass, SimpleStruct]
+ g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2]
+
+
+class ChildComplexStruct(ComplexStruct):
+ h: int
+ i: str
+
+
+class SimpleEnum(enum.Enum):
+ ONE = 1
+ TWO = 2
+
+
+class SimpleIntEnum(enum.IntEnum):
+ ONE = 1
+ TWO = 2
+
+
+def func_none(arg: None) -> str:
+ return "Received None"
+
+
+def func_bool(arg: bool) -> str:
+ return f"Received bool: {arg}"
+
+
+def func_int(arg: int) -> str:
+ return f"Received int: {arg}"
+
+
+def func_float(arg: float) -> str:
+ return f"Received float: {arg}"
+
+
+def func_str(arg: str) -> str:
+ return f"Received str: {arg}"
+
+
+def func_bytes(arg: bytes) -> str:
+ return f"Received bytes: {arg}"
+
+
+def func_bytearray(arg: bytearray) -> str:
+ return f"Received bytearray: {arg}"
+
+
+def func_tuple(arg: tuple) -> str:
+ return f"Received tuple: {arg}"
+
+
+def func_list(arg: list) -> str:
+ return f"Received list: {arg}"
+
+
+def func_dict(arg: dict) -> str:
+ return f"Received dict: {arg}"
+
+
+def func_optional_dict(arg: Optional[dict]) -> str:
+ return f"Received dict: {arg}"
+
+
+def func_set(arg: set) -> str:
+ return f"Received set: {arg}"
+
+
+def func_frozenset(arg: frozenset) -> str:
+ return f"Received frozenset: {arg}"
+
+
+def func_datetime(arg: datetime.datetime) -> str:
+ return f"Received datetime: {arg}"
+
+
+def func_date(arg: date) -> str:
+ return f"Received date: {arg}"
+
+
+def func_time(arg: datetime.time) -> str:
+ return f"Received time: {arg}"
+
+
+def func_uuid(arg: uuid.UUID) -> str:
+ return f"Received UUID: {arg}"
+
+
+def func_decimal(arg: decimal.Decimal) -> str:
+ return f"Received Decimal: {arg}"
+
+
+def func_enum(arg: SimpleEnum) -> str:
+ return f"Received Enum: {arg}"
+
+
+def func_intenum(arg: SimpleIntEnum) -> str:
+ return f"Received IntEnum: {arg}"
+
+
+def func_dataclass(arg: SimpleDataclass) -> str:
+ return f"Received dataclass: {arg}"
+
+
+def func_tuple_typing(arg: typing.Tuple[int, str]) -> str:
+ return f"Received typing.Tuple: {arg}"
+
+
+def func_list_typing(arg: typing.List[int]) -> str:
+ return f"Received typing.List: {arg}"
+
+
+def func_dict_typing(arg: typing.Dict[str, int]) -> str:
+ return f"Received typing.Dict: {arg}"
+
+
+def func_set_typing(arg: typing.Set[int]) -> str:
+ return f"Received typing.Set: {arg}"
+
+
+def func_frozenset_typing(arg: typing.FrozenSet[int]) -> str:
+ return f"Received typing.FrozenSet: {arg}"
+
+
+def func_any_typing(arg: typing.Any) -> str:
+ return f"Received typing.Any: {arg}"
+
+
+def func_union_typing(arg: typing.Union[int, str]) -> str:
+ return f"Received typing.Union: {arg}"
+
+
+def func_optional_typing(arg: typing.Optional[int]) -> str:
+ return f"Received typing.Optional: {arg}"
+
+
+def func_msgspec_struct(arg: SimpleStruct) -> str:
+ return f"Received msgspec.Struct: {arg}"
+
+
+def func_msgspec_struct_complex(arg: ComplexStruct) -> str:
+ return f"Received msgspec.Struct: {arg}"
+
+
+def func_child_complex_struct(arg: ChildComplexStruct) -> str:
+ return f"Received msgspec.Struct: {arg}"
+
+
+def func_return_optional_child_complex_struct() -> Optional[ChildComplexStruct]:
+ return None
+
+
+def func_return_complex_struct() -> ComplexStruct:
+ return ComplexStruct(
+ a=1,
+ b="hello",
+ c=SimpleStruct(h=1, i="hello"),
+ d=[SimpleStruct(h=1, i="hello")],
+ e={"1": SimpleStruct(h=1, i="hello")},
+ f=(SimpleDataclass(a=1, b="hello"), SimpleStruct(h=1, i="hello")),
+ g=SimpleDataclass(a=1, b="hello"),
+ )
+
+
+def func_take_optional_child_dataclass_return_optional_child_complex_struct(
+ arg: Optional[ChildDataclass],
+) -> Optional[ChildComplexStruct]:
+ return None
+
+
+class TestCodegen(unittest.TestCase):
+ def setUp(self) -> None:
+ self.maxDiff = None
+ self._rpc_router = {
+ "func_none": (func_none, False),
+ "func_bool": (func_bool, False),
+ "func_int": (func_int, False),
+ "func_float": (func_float, False),
+ "func_str": (func_str, False),
+ "func_bytes": (func_bytes, False),
+ "func_bytearray": (func_bytearray, False),
+ "func_tuple": (func_tuple, False),
+ "func_list": (func_list, False),
+ "func_dict": (func_dict, False),
+ "func_optional_dict": (func_optional_dict, False),
+ "func_set": (func_set, False),
+ "func_frozenset": (func_frozenset, False),
+ "func_datetime": (func_datetime, False),
+ "func_date": (func_date, False),
+ "func_time": (func_time, False),
+ "func_uuid": (func_uuid, False),
+ "func_decimal": (func_decimal, False),
+ "func_enum": (func_enum, False),
+ "func_intenum": (func_intenum, False),
+ "func_dataclass": (func_dataclass, False),
+ "func_tuple_typing": (func_tuple_typing, False),
+ "func_list_typing": (func_list_typing, False),
+ "func_dict_typing": (func_dict_typing, False),
+ "func_set_typing": (func_set_typing, False),
+ "func_frozenset_typing": (func_frozenset_typing, False),
+ "func_any_typing": (func_any_typing, False),
+ "func_union_typing": (func_union_typing, False),
+ "func_optional_typing": (func_optional_typing, False),
+ "func_msgspec_struct": (func_msgspec_struct, False),
+ "func_msgspec_struct_complex": (func_msgspec_struct_complex, False),
+ "func_child_complex_struct": (func_child_complex_struct, False),
+ "func_return_complex_struct": (func_return_complex_struct, False),
+ }
+ self._rpc_input_type_map = {
+ "func_none": None,
+ "func_bool": bool,
+ "func_int": int,
+ "func_float": float,
+ "func_str": str,
+ "func_bytes": bytes,
+ "func_bytearray": bytearray,
+ "func_tuple": tuple,
+ "func_list": list,
+ "func_dict": dict,
+ "func_optional_dict": Optional[dict],
+ "func_set": set,
+ "func_frozenset": frozenset,
+ "func_datetime": datetime.datetime,
+ "func_date": datetime.date,
+ "func_time": datetime.time,
+ "func_uuid": uuid.UUID,
+ "func_decimal": decimal.Decimal,
+ "func_enum": SimpleEnum,
+ "func_intenum": SimpleIntEnum,
+ "func_dataclass": SimpleDataclass,
+ "func_tuple_typing": typing.Tuple[int, str],
+ "func_list_typing": typing.List[int],
+ "func_dict_typing": typing.Dict[str, int],
+ "func_set_typing": typing.Set[int],
+ "func_frozenset_typing": typing.FrozenSet[int],
+ "func_any_typing": typing.Any,
+ "func_union_typing": typing.Union[int, str],
+ "func_optional_typing": typing.Optional[int],
+ "func_msgspec_struct": SimpleStruct,
+ "func_msgspec_struct_complex": ComplexStruct,
+ "func_child_complex_struct": ChildComplexStruct,
+ "func_return_complex_struct": None,
+ }
+ self._rpc_return_type_map = {
+ "func_none": str,
+ "func_bool": str,
+ "func_int": str,
+ "func_float": str,
+ "func_str": str,
+ "func_bytes": str,
+ "func_bytearray": str,
+ "func_tuple": str,
+ "func_list": str,
+ "func_dict": str,
+ "func_optional_dict": Optional[str],
+ "func_set": str,
+ "func_frozenset": str,
+ "func_datetime": str,
+ "func_date": str,
+ "func_time": str,
+ "func_uuid": str,
+ "func_decimal": str,
+ "func_enum": str,
+ "func_intenum": str,
+ "func_dataclass": str,
+ "func_tuple_typing": str,
+ "func_list_typing": str,
+ "func_dict_typing": str,
+ "func_set_typing": str,
+ "func_frozenset_typing": str,
+ "func_any_typing": str,
+ "func_union_typing": str,
+ "func_optional_typing": str,
+ "func_msgspec_struct": str,
+ "func_msgspec_struct_complex": str,
+ "func_child_complex_struct": str,
+ "func_return_complex_struct": ComplexStruct,
+ }
+
+ def test_codegen(self):
+ codegen = CodeGen(
+ self._rpc_router, self._rpc_input_type_map, self._rpc_return_type_map
+ )
+ code = codegen.generate_code()
+ expected_code = """# Generated by Zero
+# import types as per needed, not all imports are shown here
+from dataclasses import dataclass
+from datetime import date, datetime, time
+import decimal
+import enum
+import msgspec
+from msgspec import Struct
+from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union
+import uuid
+
+from zero import ZeroClient
+
+
+zero_client = ZeroClient("localhost", 5559)
+
+class SimpleEnum(enum.Enum):
+ ONE = 1
+ TWO = 2
+
+
+class SimpleIntEnum(enum.IntEnum):
+ ONE = 1
+ TWO = 2
+
+
+@dataclass
+class SimpleDataclass:
+ a: int
+ b: str
+
+
+class SimpleStruct(Struct):
+ h: int
+ i: str
+
+
+@dataclass
+class SimpleDataclass2:
+ c: int
+ d: str
+
+
+class ComplexStruct(msgspec.Struct):
+ a: int
+ b: str
+ c: SimpleStruct
+ d: List[SimpleStruct]
+ e: Dict[str, SimpleStruct]
+ f: Tuple[SimpleDataclass, SimpleStruct]
+ g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2]
+
+
+class ChildComplexStruct(ComplexStruct):
+ h: int
+ i: str
+
+
+
+class RpcClient:
+ def __init__(self, zero_client: ZeroClient):
+ self._zero_client = zero_client
+
+ def func_none(selfarg: None) -> str:
+ return self._zero_client.call("func_none", None)
+
+ def func_bool(self, arg: bool) -> str:
+ return self._zero_client.call("func_bool", arg)
+
+ def func_int(self, arg: int) -> str:
+ return self._zero_client.call("func_int", arg)
+
+ def func_float(self, arg: float) -> str:
+ return self._zero_client.call("func_float", arg)
+
+ def func_str(self, arg: str) -> str:
+ return self._zero_client.call("func_str", arg)
+
+ def func_bytes(self, arg: bytes) -> str:
+ return self._zero_client.call("func_bytes", arg)
+
+ def func_bytearray(self, arg: bytearray) -> str:
+ return self._zero_client.call("func_bytearray", arg)
+
+ def func_tuple(self, arg: tuple) -> str:
+ return self._zero_client.call("func_tuple", arg)
+
+ def func_list(self, arg: list) -> str:
+ return self._zero_client.call("func_list", arg)
+
+ def func_dict(self, arg: dict) -> str:
+ return self._zero_client.call("func_dict", arg)
+
+ def func_optional_dict(self, arg: Optional[dict]) -> str:
+ return self._zero_client.call("func_optional_dict", arg)
+
+ def func_set(self, arg: set) -> str:
+ return self._zero_client.call("func_set", arg)
+
+ def func_frozenset(self, arg: frozenset) -> str:
+ return self._zero_client.call("func_frozenset", arg)
+
+ def func_datetime(self, arg: datetime) -> str:
+ return self._zero_client.call("func_datetime", arg)
+
+ def func_date(self, arg: date) -> str:
+ return self._zero_client.call("func_date", arg)
+
+ def func_time(self, arg: time) -> str:
+ return self._zero_client.call("func_time", arg)
+
+ def func_uuid(self, arg: uuid.UUID) -> str:
+ return self._zero_client.call("func_uuid", arg)
+
+ def func_decimal(self, arg: decimal.Decimal) -> str:
+ return self._zero_client.call("func_decimal", arg)
+
+ def func_enum(self, arg: SimpleEnum) -> str:
+ return self._zero_client.call("func_enum", arg)
+
+ def func_intenum(self, arg: SimpleIntEnum) -> str:
+ return self._zero_client.call("func_intenum", arg)
+
+ def func_dataclass(self, arg: SimpleDataclass) -> str:
+ return self._zero_client.call("func_dataclass", arg)
+
+ def func_tuple_typing(self, arg: Tuple[int, str]) -> str:
+ return self._zero_client.call("func_tuple_typing", arg)
+
+ def func_list_typing(self, arg: List[int]) -> str:
+ return self._zero_client.call("func_list_typing", arg)
+
+ def func_dict_typing(self, arg: Dict[str, int]) -> str:
+ return self._zero_client.call("func_dict_typing", arg)
+
+ def func_set_typing(self, arg: Set[int]) -> str:
+ return self._zero_client.call("func_set_typing", arg)
+
+ def func_frozenset_typing(self, arg: FrozenSet[int]) -> str:
+ return self._zero_client.call("func_frozenset_typing", arg)
+
+ def func_any_typing(self, arg: Any) -> str:
+ return self._zero_client.call("func_any_typing", arg)
+
+ def func_union_typing(self, arg: Union[int, str]) -> str:
+ return self._zero_client.call("func_union_typing", arg)
+
+ def func_optional_typing(self, arg: Optional[int]) -> str:
+ return self._zero_client.call("func_optional_typing", arg)
+
+ def func_msgspec_struct(self, arg: SimpleStruct) -> str:
+ return self._zero_client.call("func_msgspec_struct", arg)
+
+ def func_msgspec_struct_complex(self, arg: ComplexStruct) -> str:
+ return self._zero_client.call("func_msgspec_struct_complex", arg)
+
+ def func_child_complex_struct(self, arg: ChildComplexStruct) -> str:
+ return self._zero_client.call("func_child_complex_struct", arg)
+
+ def func_return_complex_struct(self) -> ComplexStruct:
+ return self._zero_client.call("func_return_complex_struct", None)
+"""
+ self.assertEqual(code, expected_code)
+
+ def test_codegen_return_single_complex_struct(self):
+ rpc_router = {
+ "func_return_complex_struct": (func_return_complex_struct, False),
+ }
+ rpc_input_type_map = {
+ "func_return_complex_struct": None,
+ }
+ rpc_return_type_map = {
+ "func_return_complex_struct": ComplexStruct,
+ }
+ codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map)
+ code = codegen.generate_code()
+ expected_code = """# Generated by Zero
+# import types as per needed, not all imports are shown here
+from dataclasses import dataclass
+import enum
+import msgspec
+from msgspec import Struct
+from typing import Dict, List, Tuple, Union
+
+from zero import ZeroClient
+
+
+zero_client = ZeroClient("localhost", 5559)
+
+class SimpleStruct(Struct):
+ h: int
+ i: str
+
+
+@dataclass
+class SimpleDataclass:
+ a: int
+ b: str
+
+
+@dataclass
+class SimpleDataclass2:
+ c: int
+ d: str
+
+
+class ComplexStruct(msgspec.Struct):
+ a: int
+ b: str
+ c: SimpleStruct
+ d: List[SimpleStruct]
+ e: Dict[str, SimpleStruct]
+ f: Tuple[SimpleDataclass, SimpleStruct]
+ g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2]
+
+
+
+class RpcClient:
+ def __init__(self, zero_client: ZeroClient):
+ self._zero_client = zero_client
+
+ def func_return_complex_struct(self) -> ComplexStruct:
+ return self._zero_client.call("func_return_complex_struct", None)
+"""
+ self.assertEqual(code, expected_code)
+
+ def test_codegen_return_optional_complex_struct(self):
+ rpc_router = {
+ "func_return_optional_child_complex_struct": (
+ func_return_optional_child_complex_struct,
+ False,
+ ),
+ }
+ rpc_input_type_map = {
+ "func_return_optional_child_complex_struct": None,
+ }
+ rpc_return_type_map = {
+ "func_return_optional_child_complex_struct": Optional[ChildComplexStruct],
+ }
+ codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map)
+ code = codegen.generate_code()
+ expected_code = """# Generated by Zero
+# import types as per needed, not all imports are shown here
+from dataclasses import dataclass
+import enum
+import msgspec
+from msgspec import Struct
+from typing import Dict, List, Optional, Tuple, Union
+
+from zero import ZeroClient
+
+
+zero_client = ZeroClient("localhost", 5559)
+
+class SimpleStruct(Struct):
+ h: int
+ i: str
+
+
+@dataclass
+class SimpleDataclass:
+ a: int
+ b: str
+
+
+@dataclass
+class SimpleDataclass2:
+ c: int
+ d: str
+
+
+class ComplexStruct(msgspec.Struct):
+ a: int
+ b: str
+ c: SimpleStruct
+ d: List[SimpleStruct]
+ e: Dict[str, SimpleStruct]
+ f: Tuple[SimpleDataclass, SimpleStruct]
+ g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2]
+
+
+class ChildComplexStruct(ComplexStruct):
+ h: int
+ i: str
+
+
+
+class RpcClient:
+ def __init__(self, zero_client: ZeroClient):
+ self._zero_client = zero_client
+
+ def func_return_optional_child_complex_struct(self) -> Optional[ChildComplexStruct]:
+ return self._zero_client.call("func_return_optional_child_complex_struct", None)
+"""
+ self.assertEqual(code, expected_code)
+
+ def test_codegen_optional_child_dataclass_return_optional_child_complex_struct(
+ self,
+ ):
+ rpc_router = {
+ "func_take_optional_child_dataclass_return_optional_child_complex_struct": (
+ func_take_optional_child_dataclass_return_optional_child_complex_struct,
+ False,
+ ),
+ }
+ rpc_input_type_map = {
+ "func_take_optional_child_dataclass_return_optional_child_complex_struct": Optional[
+ ChildDataclass
+ ],
+ }
+ rpc_return_type_map = {
+ "func_take_optional_child_dataclass_return_optional_child_complex_struct": Optional[
+ ChildComplexStruct
+ ],
+ }
+ codegen = CodeGen(rpc_router, rpc_input_type_map, rpc_return_type_map)
+ code = codegen.generate_code()
+ expected_code = """# Generated by Zero
+# import types as per needed, not all imports are shown here
+from dataclasses import dataclass
+import enum
+import msgspec
+from msgspec import Struct
+from typing import Dict, List, Optional, Tuple, Union
+
+from zero import ZeroClient
+
+
+zero_client = ZeroClient("localhost", 5559)
+
+@dataclass
+class SimpleDataclass:
+ a: int
+ b: str
+
+
+@dataclass
+class ChildDataclass(SimpleDataclass):
+ e: int
+ f: str
+
+
+class SimpleStruct(Struct):
+ h: int
+ i: str
+
+
+@dataclass
+class SimpleDataclass2:
+ c: int
+ d: str
+
+
+class ComplexStruct(msgspec.Struct):
+ a: int
+ b: str
+ c: SimpleStruct
+ d: List[SimpleStruct]
+ e: Dict[str, SimpleStruct]
+ f: Tuple[SimpleDataclass, SimpleStruct]
+ g: Union[SimpleStruct, SimpleDataclass, SimpleDataclass2]
+
+
+class ChildComplexStruct(ComplexStruct):
+ h: int
+ i: str
+
+
+
+class RpcClient:
+ def __init__(self, zero_client: ZeroClient):
+ self._zero_client = zero_client
+
+ def func_take_optional_child_dataclass_return_optional_child_complex_struct(self,
+ arg: Optional[ChildDataclass],
+) -> Optional[ChildComplexStruct]:
+ return self._zero_client.call("func_take_optional_child_dataclass_return_optional_child_complex_struct", arg)
+"""
+ self.assertEqual(code, expected_code)
diff --git a/tests/functional/single_server/client_generation_test.py b/tests/functional/single_server/client_generation_test.py
index 22d3fc7..da1b33c 100644
--- a/tests/functional/single_server/client_generation_test.py
+++ b/tests/functional/single_server/client_generation_test.py
@@ -17,18 +17,129 @@ def test_codegeneration():
assert (
code
== """# Generated by Zero
-# import types as per needed
+# import types as per needed, not all imports are shown here
+from dataclasses import dataclass
+from datetime import date, datetime, time
+import decimal
+import enum
+import msgspec
+from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union
+import uuid
from zero import ZeroClient
zero_client = ZeroClient("localhost", 5559)
+class Color(enum.Enum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+
+class ColorInt(enum.IntEnum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+
+@dataclass
+class Dataclass:
+ name: str
+ age: int
+
+
+class Message(msgspec.Struct):
+ msg: str
+ start_time: datetime
+
+
class RpcClient:
def __init__(self, zero_client: ZeroClient):
self._zero_client = zero_client
+ def echo_bool(self, msg: bool) -> bool:
+ return self._zero_client.call("echo_bool", msg)
+
+ def echo_int(self, msg: int) -> int:
+ return self._zero_client.call("echo_int", msg)
+
+ def echo_float(self, msg: float) -> float:
+ return self._zero_client.call("echo_float", msg)
+
+ def echo_str(self, msg: str) -> str:
+ return self._zero_client.call("echo_str", msg)
+
+ def echo_bytes(self, msg: bytes) -> bytes:
+ return self._zero_client.call("echo_bytes", msg)
+
+ def echo_bytearray(self, msg: bytearray) -> bytearray:
+ return self._zero_client.call("echo_bytearray", msg)
+
+ def echo_tuple(self, msg: Tuple[int, str]) -> Tuple[int, str]:
+ return self._zero_client.call("echo_tuple", msg)
+
+ def echo_list(self, msg: List[int]) -> List[int]:
+ return self._zero_client.call("echo_list", msg)
+
+ def echo_dict(self, msg: Dict[int, str]) -> Dict[int, str]:
+ return self._zero_client.call("echo_dict", msg)
+
+ def echo_set(self, msg: Set[int]) -> Set[int]:
+ return self._zero_client.call("echo_set", msg)
+
+ def echo_frozenset(self, msg: FrozenSet[int]) -> FrozenSet[int]:
+ return self._zero_client.call("echo_frozenset", msg)
+
+ def echo_datetime(self, msg: datetime) -> datetime:
+ return self._zero_client.call("echo_datetime", msg)
+
+ def echo_date(self, msg: date) -> date:
+ return self._zero_client.call("echo_date", msg)
+
+ def echo_time(self, msg: time) -> time:
+ return self._zero_client.call("echo_time", msg)
+
+ def echo_uuid(self, msg: uuid.UUID) -> uuid.UUID:
+ return self._zero_client.call("echo_uuid", msg)
+
+ def echo_decimal(self, msg: decimal.Decimal) -> decimal.Decimal:
+ return self._zero_client.call("echo_decimal", msg)
+
+ def echo_enum(self, msg: Color) -> Color:
+ return self._zero_client.call("echo_enum", msg)
+
+ def echo_enum_int(self, msg: ColorInt) -> ColorInt:
+ return self._zero_client.call("echo_enum_int", msg)
+
+ def echo_dataclass(self, msg: Dataclass) -> Dataclass:
+ return self._zero_client.call("echo_dataclass", msg)
+
+ def echo_typing_tuple(self, msg: Tuple[int, str]) -> Tuple[int, str]:
+ return self._zero_client.call("echo_typing_tuple", msg)
+
+ def echo_typing_list(self, msg: List[int]) -> List[int]:
+ return self._zero_client.call("echo_typing_list", msg)
+
+ def echo_typing_dict(self, msg: Dict[int, str]) -> Dict[int, str]:
+ return self._zero_client.call("echo_typing_dict", msg)
+
+ def echo_typing_set(self, msg: Set[int]) -> Set[int]:
+ return self._zero_client.call("echo_typing_set", msg)
+
+ def echo_typing_frozenset(self, msg: FrozenSet[int]) -> FrozenSet[int]:
+ return self._zero_client.call("echo_typing_frozenset", msg)
+
+ def echo_typing_union(self, msg: Union[int, str]) -> Union[int, str]:
+ return self._zero_client.call("echo_typing_union", msg)
+
+ def echo_typing_optional(self, msg: Optional[int]) -> int:
+ return self._zero_client.call("echo_typing_optional", msg)
+
+ def echo_msgspec_struct(self, msg: Message) -> Message:
+ return self._zero_client.call("echo_msgspec_struct", msg)
+
def sleep(self, msec: int) -> str:
return self._zero_client.call("sleep", msec)
@@ -38,9 +149,6 @@ def sleep_async(self, msec: int) -> str:
def error(self, msg: str) -> str:
return self._zero_client.call("error", msg)
- def msgspec_struct(self, start: datetime.datetime) -> Message:
- return self._zero_client.call("msgspec_struct", start)
-
def send_bytes(self, msg: bytes) -> bytes:
return self._zero_client.call("send_bytes", msg)
@@ -53,19 +161,10 @@ def hello_world(self) -> str:
def decode_jwt(self, msg: str) -> str:
return self._zero_client.call("decode_jwt", msg)
- def sum_list(self, msg: typing.List[int]) -> int:
+ def sum_list(self, msg: List[int]) -> int:
return self._zero_client.call("sum_list", msg)
- def echo_dict(self, msg: typing.Dict[int, str]) -> typing.Dict[int, str]:
- return self._zero_client.call("echo_dict", msg)
-
- def echo_tuple(self, msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]:
- return self._zero_client.call("echo_tuple", msg)
-
- def echo_union(self, msg: typing.Union[int, str]) -> typing.Union[int, str]:
- return self._zero_client.call("echo_union", msg)
-
- def divide(self, msg: typing.Tuple[int, int]) -> int:
+ def divide(self, msg: Tuple[int, int]) -> int:
return self._zero_client.call("divide", msg)
"""
)
diff --git a/tests/functional/single_server/client_server_test.py b/tests/functional/single_server/client_server_test.py
index 6d79246..b29a809 100644
--- a/tests/functional/single_server/client_server_test.py
+++ b/tests/functional/single_server/client_server_test.py
@@ -1,4 +1,7 @@
import datetime
+import decimal
+import typing
+import uuid
import pytest
import requests
@@ -11,14 +14,193 @@
from .server import Message
-def test_hello_world():
- zero_client = ZeroClient(server.HOST, server.PORT)
+@pytest.fixture
+def zero_client():
+ return ZeroClient(server.HOST, server.PORT)
+
+
+# bool input
+def test_echo_bool(zero_client):
+ assert zero_client.call("echo_bool", True) is True
+
+
+# int input
+def test_echo_int(zero_client):
+ assert zero_client.call("echo_int", 42) == 42
+
+
+# float input
+def test_echo_float(zero_client):
+ assert zero_client.call("echo_float", 3.14) == 3.14
+
+
+# str input
+def test_echo_str(zero_client):
+ assert zero_client.call("echo_str", "hello") == "hello"
+
+
+# bytes input
+def test_echo_bytes(zero_client):
+ assert zero_client.call("echo_bytes", b"hello") == b"hello"
+
+
+# bytearray input
+def test_echo_bytearray(zero_client):
+ assert zero_client.call("echo_bytearray", bytearray(b"hello")) == bytearray(
+ b"hello"
+ )
+
+
+# tuple input
+def test_echo_tuple(zero_client):
+ assert zero_client.call("echo_tuple", (1, "a"), return_type=tuple) == (1, "a")
+
+
+# list input
+def test_echo_list(zero_client):
+ assert zero_client.call("echo_list", [1, 2, 3]) == [1, 2, 3]
+
+
+# dict input
+def test_echo_dict(zero_client):
+ assert zero_client.call("echo_dict", {1: "a"}) == {1: "a"}
+
+
+# set input
+def test_echo_set(zero_client):
+ assert zero_client.call("echo_set", {1, 2, 3}, return_type=set) == {1, 2, 3}
+
+
+# frozenset input
+def test_echo_frozenset(zero_client):
+ assert zero_client.call(
+ "echo_frozenset", frozenset({1, 2, 3}), return_type=frozenset
+ ) == frozenset({1, 2, 3})
+
+
+# datetime input
+def test_echo_datetime(zero_client):
+ now = datetime.datetime.now()
+ assert zero_client.call("echo_datetime", now, return_type=datetime.datetime) == now
+
+
+# date input
+def test_echo_date(zero_client):
+ today = datetime.date.today()
+ assert zero_client.call("echo_date", today, return_type=datetime.date) == today
+
+
+# time input
+def test_echo_time(zero_client):
+ now = datetime.datetime.now().time()
+ assert zero_client.call("echo_time", now, return_type=datetime.time) == now
+
+
+# uuid input
+def test_echo_uuid(zero_client):
+ uid = uuid.uuid4()
+ assert zero_client.call("echo_uuid", uid, return_type=uuid.UUID) == uid
+
+
+# decimal input
+def test_echo_decimal(zero_client):
+ value = decimal.Decimal("10.1")
+ assert zero_client.call("echo_decimal", value, return_type=decimal.Decimal) == value
+
+
+# enum input
+def test_echo_enum(zero_client):
+ assert (
+ zero_client.call("echo_enum", server.Color.RED, return_type=server.Color)
+ == server.Color.RED
+ )
+
+
+# enum int input
+def test_echo_enum_int(zero_client):
+ assert (
+ zero_client.call("echo_enum_int", server.ColorInt.GREEN)
+ == server.ColorInt.GREEN
+ )
+
+
+# dataclass input
+def test_echo_dataclass(zero_client):
+ data = server.Dataclass(name="John", age=30)
+ result = zero_client.call("echo_dataclass", data, return_type=server.Dataclass)
+ assert result == data
+
+
+# typing.Tuple input
+def test_echo_typing_tuple(zero_client):
+ assert zero_client.call(
+ "echo_typing_tuple", (1, "a"), return_type=typing.Tuple
+ ) == (1, "a")
+
+
+# typing.List input
+def test_echo_typing_list(zero_client):
+ assert zero_client.call("echo_typing_list", [1, 2, 3]) == [1, 2, 3]
+
+
+# typing.Dict input
+def test_echo_typing_dict(zero_client):
+ assert zero_client.call("echo_typing_dict", {1: "a"}, return_type=typing.Dict) == {
+ 1: "a"
+ }
+
+
+# typing.Set input
+def test_echo_typing_set(zero_client):
+ assert zero_client.call("echo_typing_set", {1, 2, 3}, return_type=typing.Set) == {
+ 1,
+ 2,
+ 3,
+ }
+
+
+# typing.FrozenSet input
+def test_echo_typing_frozenset(zero_client):
+ assert zero_client.call(
+ "echo_typing_frozenset", frozenset({1, 2, 3}), return_type=typing.FrozenSet
+ ) == frozenset({1, 2, 3})
+
+
+# typing.Union input
+def test_echo_typing_union(zero_client):
+ assert (
+ zero_client.call("echo_typing_union", 1, return_type=typing.Union[str, int])
+ == 1
+ )
+ assert (
+ zero_client.call("echo_typing_union", "a", return_type=typing.Union[str, int])
+ == "a"
+ )
+
+
+# typing.Optional input
+def test_echo_typing_optional(zero_client):
+ assert zero_client.call("echo_typing_optional", None) == 0
+ assert (
+ zero_client.call("echo_typing_optional", 1, return_type=typing.Optional[int])
+ == 1
+ )
+
+
+# msgspec.Struct input
+def test_echo_msgspec_struct(zero_client):
+ msg = server.Message(msg="hello world", start_time=datetime.datetime.now())
+ result = zero_client.call("echo_msgspec_struct", msg, return_type=server.Message)
+ assert result.msg == msg.msg
+ assert result.start_time == msg.start_time
+
+
+def test_hello_world(zero_client):
msg = zero_client.call("hello_world", "")
assert msg == "hello world"
-def test_necho():
- zero_client = ZeroClient(server.HOST, server.PORT)
+def test_necho(zero_client):
with pytest.raises(zero.error.MethodNotFoundException):
msg = zero_client.call("necho", "hello")
assert msg is None
@@ -31,20 +213,12 @@ def test_echo_wrong_port():
assert msg is None
-def test_sum_list():
- zero_client = ZeroClient(server.HOST, server.PORT)
+def test_sum_list(zero_client):
msg = zero_client.call("sum_list", [1, 2, 3])
assert msg == 6
-def test_echo_dict():
- zero_client = ZeroClient(server.HOST, server.PORT)
- msg = zero_client.call("echo_dict", {1: "b"})
- assert msg == {1: "b"}
-
-
-def test_echo_dict_validation_error():
- zero_client = ZeroClient(server.HOST, server.PORT)
+def test_echo_dict_validation_error(zero_client):
with pytest.raises(ValidationException):
msg = zero_client.call("echo_dict", {"a": "b"})
assert msg == {
@@ -52,16 +226,14 @@ def test_echo_dict_validation_error():
}
-def test_echo_tuple():
- zero_client = ZeroClient(server.HOST, server.PORT)
+def test_echo_tuple_2(zero_client):
msg = zero_client.call("echo_tuple", (1, "a"))
assert isinstance(msg, list) # IMPORTANT
assert msg == [1, "a"]
-def test_echo_union():
- zero_client = ZeroClient(server.HOST, server.PORT)
- msg = zero_client.call("echo_union", 1)
+def test_echo_union(zero_client):
+ msg = zero_client.call("echo_typing_union", 1)
assert msg == 1
@@ -106,11 +278,11 @@ class Example:
def test_msgspec_struct():
- now = datetime.datetime.now()
+ msg = Message("hello world", datetime.datetime.now())
zero_client = ZeroClient(server.HOST, server.PORT)
- msg = zero_client.call("msgspec_struct", now, return_type=Message)
+ msg = zero_client.call("echo_msgspec_struct", msg, return_type=Message)
assert msg.msg == "hello world"
- assert msg.start_time == now
+ assert msg.start_time == msg.start_time
def test_send_bytes():
@@ -121,18 +293,18 @@ def test_send_bytes():
def test_send_http_request():
with pytest.raises(requests.exceptions.ReadTimeout):
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
def test_server_works_after_multiple_http_requests():
"""Because of this issue https://github.com/Ananto30/zero/issues/41"""
try:
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
- requests.get(f"http://{server.HOST}:{server.PORT}", timeout=2)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
+ requests.get(f"http://{server.HOST}:{server.PORT}", timeout=0.1)
except requests.exceptions.ReadTimeout:
pass
zero_client = ZeroClient(server.HOST, server.PORT)
diff --git a/tests/functional/single_server/server.py b/tests/functional/single_server/server.py
index 7d4d278..f6940c2 100644
--- a/tests/functional/single_server/server.py
+++ b/tests/functional/single_server/server.py
@@ -1,7 +1,11 @@
import asyncio
import datetime
+import decimal
+import enum
import time
import typing
+import uuid
+from dataclasses import dataclass
import jwt
import msgspec
@@ -14,36 +18,210 @@
app = ZeroServer(port=PORT)
-async def echo(msg: str) -> str:
+# None input
+async def hello_world() -> str:
+ return "hello world"
+
+
+# bool input
+@app.register_rpc
+def echo_bool(msg: bool) -> bool:
return msg
-async def hello_world() -> str:
- return "hello world"
+# int input
+@app.register_rpc
+def echo_int(msg: int) -> int:
+ return msg
-async def decode_jwt(msg: str) -> str:
- encoded_jwt = jwt.encode(msg, "secret", algorithm="HS256") # type: ignore
- decoded_jwt = jwt.decode(encoded_jwt, "secret", algorithms=["HS256"])
- return decoded_jwt # type: ignore
+# float input
+@app.register_rpc
+def echo_float(msg: float) -> float:
+ return msg
-def sum_list(msg: typing.List[int]) -> int:
- return sum(msg)
+# str input
+@app.register_rpc
+def echo_str(msg: str) -> str:
+ return msg
-def echo_dict(msg: typing.Dict[int, str]) -> typing.Dict[int, str]:
+# bytes input
+@app.register_rpc
+def echo_bytes(msg: bytes) -> bytes:
return msg
+# bytearray input
+@app.register_rpc
+def echo_bytearray(msg: bytearray) -> bytearray:
+ return msg
+
+
+# tuple input
+@app.register_rpc
def echo_tuple(msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]:
return msg
-def echo_union(msg: typing.Union[int, str]) -> typing.Union[int, str]:
+# list input
+@app.register_rpc
+def echo_list(msg: typing.List[int]) -> typing.List[int]:
+ return msg
+
+
+# dict input
+@app.register_rpc
+def echo_dict(msg: typing.Dict[int, str]) -> typing.Dict[int, str]:
return msg
+# set input
+@app.register_rpc
+def echo_set(msg: typing.Set[int]) -> typing.Set[int]:
+ return msg
+
+
+# frozenset input
+@app.register_rpc
+def echo_frozenset(msg: typing.FrozenSet[int]) -> typing.FrozenSet[int]:
+ return msg
+
+
+# datetime input
+@app.register_rpc
+def echo_datetime(msg: datetime.datetime) -> datetime.datetime:
+ return msg
+
+
+# date input
+@app.register_rpc
+def echo_date(msg: datetime.date) -> datetime.date:
+ return msg
+
+
+# time input
+@app.register_rpc
+def echo_time(msg: datetime.time) -> datetime.time:
+ return msg
+
+
+# uuid input
+@app.register_rpc
+def echo_uuid(msg: uuid.UUID) -> uuid.UUID:
+ return msg
+
+
+# decimal input
+@app.register_rpc
+def echo_decimal(msg: decimal.Decimal) -> decimal.Decimal:
+ return msg
+
+
+# enum input
+class Color(enum.Enum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+
+@app.register_rpc
+def echo_enum(msg: Color) -> Color:
+ return msg
+
+
+# enum int input
+class ColorInt(enum.IntEnum):
+ RED = 1
+ GREEN = 2
+ BLUE = 3
+
+
+@app.register_rpc
+def echo_enum_int(msg: ColorInt) -> ColorInt:
+ return msg
+
+
+# dataclass input
+@dataclass
+class Dataclass:
+ name: str
+ age: int
+
+
+@app.register_rpc
+def echo_dataclass(msg: Dataclass) -> Dataclass:
+ return msg
+
+
+# typing.Tuple input
+@app.register_rpc
+def echo_typing_tuple(msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]:
+ return msg
+
+
+# typing.List input
+@app.register_rpc
+def echo_typing_list(msg: typing.List[int]) -> typing.List[int]:
+ return msg
+
+
+# typing.Dict input
+@app.register_rpc
+def echo_typing_dict(msg: typing.Dict[int, str]) -> typing.Dict[int, str]:
+ return msg
+
+
+# typing.Set input
+@app.register_rpc
+def echo_typing_set(msg: typing.Set[int]) -> typing.Set[int]:
+ return msg
+
+
+# typing.FrozenSet input
+@app.register_rpc
+def echo_typing_frozenset(msg: typing.FrozenSet[int]) -> typing.FrozenSet[int]:
+ return msg
+
+
+# typing.Union input
+@app.register_rpc
+def echo_typing_union(msg: typing.Union[int, str]) -> typing.Union[int, str]:
+ return msg
+
+
+# typing.Optional input
+@app.register_rpc
+def echo_typing_optional(msg: typing.Optional[int]) -> int:
+ return msg or 0
+
+
+# msgspec.Struct input
+class Message(msgspec.Struct):
+ msg: str
+ start_time: datetime.datetime
+
+
+@app.register_rpc
+def echo_msgspec_struct(msg: Message) -> Message:
+ return msg
+
+
+async def echo(msg: str) -> str:
+ return msg
+
+
+async def decode_jwt(msg: str) -> str:
+ encoded_jwt = jwt.encode(msg, "secret", algorithm="HS256") # type: ignore
+ decoded_jwt = jwt.decode(encoded_jwt, "secret", algorithms=["HS256"])
+ return decoded_jwt # type: ignore
+
+
+def sum_list(msg: typing.List[int]) -> int:
+ return sum(msg)
+
+
def divide(msg: typing.Tuple[int, int]) -> int:
return int(msg[0] / msg[1])
@@ -69,16 +247,6 @@ def error(msg: str) -> str:
raise RuntimeError(msg)
-class Message(msgspec.Struct):
- msg: str
- start_time: datetime.datetime
-
-
-@app.register_rpc
-def msgspec_struct(start: datetime.datetime) -> Message:
- return Message(msg="hello world", start_time=start)
-
-
@app.register_rpc
def send_bytes(msg: bytes) -> bytes:
return msg
@@ -90,9 +258,6 @@ def run(port):
app.register_rpc(hello_world)
app.register_rpc(decode_jwt)
app.register_rpc(sum_list)
- app.register_rpc(echo_dict)
- app.register_rpc(echo_tuple)
- app.register_rpc(echo_union)
app.register_rpc(divide)
app.run(2)
diff --git a/tests/functional/test_async_to_sync.py b/tests/functional/test_async_to_sync.py
new file mode 100644
index 0000000..de585aa
--- /dev/null
+++ b/tests/functional/test_async_to_sync.py
@@ -0,0 +1,51 @@
+import asyncio
+
+import pytest
+
+from zero.utils.async_to_sync import async_to_sync
+
+
+# Test case 1: Test a simple async function
+async def simple_async_function(x):
+ await asyncio.sleep(0.1) # Simulate async work
+ return x * 2
+
+
+def test_simple_async_function():
+ sync_function = async_to_sync(simple_async_function)
+ result = sync_function(5)
+ assert result == 10, "The async function should return 10 when called with 5"
+
+
+# Test case 2: Test an async function that raises an exception
+async def async_function_raises_exception():
+ raise ValueError("This is a test exception")
+
+
+def test_async_function_exception():
+ sync_function = async_to_sync(async_function_raises_exception)
+ with pytest.raises(ValueError) as exc_info:
+ sync_function()
+ assert (
+ str(exc_info.value) == "This is a test exception"
+ ), "The exception message should be 'This is a test exception'"
+
+
+# Test case 3: Test the reusability of async_to_sync for multiple functions
+async def another_simple_async_function(x):
+ await asyncio.sleep(0.1) # Simulate async work
+ return x + 100
+
+
+def test_reusability_of_async_to_sync():
+ sync_function_1 = async_to_sync(simple_async_function)
+ result_1 = sync_function_1(5)
+ assert (
+ result_1 == 10
+ ), "The first async function should return 10 when called with 5"
+
+ sync_function_2 = async_to_sync(another_simple_async_function)
+ result_2 = sync_function_2(5)
+ assert (
+ result_2 == 105
+ ), "The second async function should return 105 when called with 5"
diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py
index 7790820..c910818 100644
--- a/tests/unit/test_server.py
+++ b/tests/unit/test_server.py
@@ -159,6 +159,19 @@ def decode_type(self, message: bytes, typ: Any) -> Any:
self.assertEqual(server._rpc_input_type_map, {})
self.assertEqual(server._rpc_return_type_map, {})
+ def test_create_server_with_invalid_encoder(self):
+ with self.assertRaises(TypeError):
+ ZeroServer(encoder="encoder")
+
+ def test_create_server_with_invalid_protocol(self):
+ with self.assertRaises(ValueError):
+ ZeroServer(protocol="invalid_protocol")
+
+ def test_create_server_with_protocol_with_no_server(self):
+ with patch("zero.rpc.server.config.SUPPORTED_PROTOCOLS", {"redis": {}}):
+ with self.assertRaises(ValueError):
+ ZeroServer(protocol="redis")
+
def test_register_rpc(self):
server = ZeroServer()
@@ -248,16 +261,17 @@ def add(msg: Tuple[int, int]) -> int:
server._broker.backend, # type: ignore
)
- # TODO fix
- # # @pytest.mark.skipif(sys.platform == "win32", reason="Does not run on windows")
- # # @pytest.mark.skip
- # def test_server_run_keyboard_interrupt(self):
- # server = ZeroServer()
+ def test_server_run_keyboard_interrupt(self):
+ server = ZeroServer()
- # @server.register_rpc
- # def add(msg: Tuple[int, int]) -> int:
- # return msg[0] + msg[1]
+ @server.register_rpc
+ def add(msg: Tuple[int, int]) -> int:
+ return msg[0] + msg[1]
- # with patch.object(server, "_start_server", side_effect=KeyboardInterrupt):
- # with self.assertRaises(SystemExit):
- # server.run()
+ with patch.object(server, "_server_inst") as mock_server_inst:
+ mock_server_inst.start.side_effect = KeyboardInterrupt
+ with patch("logging.warning") as mock_warning:
+ server.run()
+ mock_warning.assert_called_with(
+ "Caught KeyboardInterrupt, terminating server"
+ )
diff --git a/tests/unit/test_type_util.py b/tests/unit/test_type_util.py
new file mode 100644
index 0000000..813fe0b
--- /dev/null
+++ b/tests/unit/test_type_util.py
@@ -0,0 +1,134 @@
+import unittest
+from typing import Optional
+from unittest.mock import MagicMock
+
+from zero.utils.type_util import (
+ get_function_input_class,
+ get_function_return_class,
+ verify_function_args,
+ verify_function_input_type,
+ verify_function_return,
+ verify_function_return_type,
+)
+
+
+class TestVerifyFunctionReturnType(unittest.TestCase):
+ def test_valid_return_type(self):
+ def func() -> int:
+ return 1
+
+ verify_function_return_type(func)
+
+ def test_none_return_type(self):
+ def func() -> None:
+ return None
+
+ with self.assertRaises(TypeError):
+ verify_function_return_type(func)
+
+ def test_optional_return_type(self):
+ def func() -> Optional[int]:
+ return None
+
+ with self.assertRaises(TypeError):
+ verify_function_return_type(func)
+
+ def test_invalid_return_type(self):
+ class CustomType:
+ pass
+
+ def func() -> CustomType:
+ return CustomType()
+
+ with self.assertRaises(TypeError):
+ verify_function_return_type(func)
+
+ def test_mocked_return_type(self):
+ def func() -> MagicMock:
+ return MagicMock()
+
+ with self.assertRaises(TypeError):
+ verify_function_return_type(func)
+
+ def test__verify_function_args__ok(self):
+ def func(a: int) -> int:
+ return a
+
+ verify_function_args(func)
+
+ def test__verify_function_args__multiple_args(self):
+ def func(a: int, b: int) -> int:
+ return a + b
+
+ with self.assertRaises(ValueError):
+ verify_function_args(func)
+
+ def test__verify_function_args__no_type_hint(self):
+ def func(a):
+ return a
+
+ with self.assertRaises(TypeError):
+ verify_function_args(func)
+
+ def test__verify_function_return__ok(self):
+ def func() -> int:
+ return 1
+
+ verify_function_return(func)
+
+ def test__verify_function_return__no_type_hint(self):
+ def func():
+ return 1
+
+ with self.assertRaises(TypeError):
+ verify_function_return(func)
+
+ def test__get_function_input_class__ok(self):
+ def func(a: int) -> int:
+ return a
+
+ self.assertEqual(get_function_input_class(func), int)
+
+ def test__get_function_input_class__no_args(self):
+ def func() -> int:
+ return 1
+
+ self.assertEqual(get_function_input_class(func), None)
+
+ def test__get_function_input_class__multiple_args(self):
+ def func(a: int, b: int) -> int:
+ return a + b
+
+ self.assertEqual(get_function_input_class(func), None)
+
+ def test__get_function_return_class__ok(self):
+ def func() -> int:
+ return 1
+
+ self.assertEqual(get_function_return_class(func), int)
+
+ def test__get_function_return_class__no_return(self):
+ def func():
+ return 1
+
+ self.assertEqual(get_function_return_class(func), None)
+
+ def test__verify_function_input_type__ok(self):
+ def func(a: int) -> int:
+ return a
+
+ verify_function_input_type(func)
+
+ def test__verify_function_input_type__invalid(self):
+ def func(a: MagicMock) -> int:
+ return a
+
+ with self.assertRaises(TypeError):
+ verify_function_input_type(func)
+
+ def test__verify_function_input_type__no_type_hint(self):
+ def func(a) -> int:
+ return a
+
+ with self.assertRaises(KeyError):
+ verify_function_input_type(func)
diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py
new file mode 100644
index 0000000..125eaa3
--- /dev/null
+++ b/tests/unit/test_util.py
@@ -0,0 +1,21 @@
+import logging
+import unittest
+from unittest.mock import patch
+
+from zero.utils.util import log_error
+
+
+class TestLogError(unittest.TestCase):
+ def test_log_error(self):
+ @log_error
+ def divide(a, b):
+ return a / b
+
+ with patch.object(logging, "exception") as mock_exception:
+ result = divide(10, 2)
+ self.assertEqual(result, 5)
+ mock_exception.assert_not_called()
+
+ result = divide(10, 0)
+ self.assertIsNone(result)
+ mock_exception.assert_called_once()
diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py
index efa0da4..0ce97d9 100644
--- a/tests/unit/test_worker.py
+++ b/tests/unit/test_worker.py
@@ -1,6 +1,10 @@
import unittest
from unittest.mock import MagicMock, Mock, patch
+import msgspec
+
+from zero.encoder.protocols import Encoder
+from zero.error import SERVER_PROCESSING_ERROR
from zero.protocols.zeromq.worker import _Worker
@@ -55,6 +59,26 @@ def test_start_dealer_worker_exception_handling(self, mock_get_worker):
self.assertIn("Test Exception", log.output[0])
mock_worker.close.assert_called_once()
+ @patch("zero.protocols.zeromq.worker.get_worker")
+ def test_start_dealer_worker_keyboard_interrupt_handling(self, mock_get_worker):
+ mock_worker = Mock()
+ mock_get_worker.return_value = mock_worker
+ mock_worker.listen.side_effect = KeyboardInterrupt
+
+ worker_id = 1
+ worker = _Worker(
+ self.rpc_router,
+ self.device_comm_channel,
+ self.encoder,
+ self.rpc_input_type_map,
+ self.rpc_return_type_map,
+ )
+
+ with self.assertLogs(level="WARNING") as log:
+ worker.start_dealer_worker(worker_id)
+ self.assertIn("terminating worker", log.output[0])
+ mock_worker.close.assert_called_once()
+
@patch("zero.protocols.zeromq.worker.async_to_sync", side_effect=lambda x: x)
def test_handle_msg_get_rpc_contract(self, mock_async_to_sync):
worker = _Worker(
@@ -70,7 +94,7 @@ def test_handle_msg_get_rpc_contract(self, mock_async_to_sync):
with patch.object(
worker, "generate_rpc_contract", return_value=expected_response
) as mock_generate_rpc_contract:
- response = worker.handle_msg("get_rpc_contract", msg)
+ response = worker.execute_rpc("get_rpc_contract", msg)
mock_generate_rpc_contract.assert_called_once_with(msg)
self.assertEqual(response, expected_response)
@@ -89,7 +113,7 @@ def test_handle_msg_rpc_call_exception(self, mock_async_to_sync):
self.rpc_return_type_map,
)
- response = worker.handle_msg("failing_function", "msg")
+ response = worker.execute_rpc("failing_function", "msg")
self.assertEqual(
response, {"__zerror__server_exception": "Exception('RPC Exception')"}
)
@@ -105,7 +129,7 @@ def test_handle_msg_connect(self):
msg = "some_message"
expected_response = "connected"
- response = worker.handle_msg("connect", msg)
+ response = worker.execute_rpc("connect", msg)
self.assertEqual(response, expected_response)
@@ -122,7 +146,7 @@ def test_handle_msg_function_not_found(self):
"__zerror__function_not_found": "Function `some_function_not_found` not found!"
}
- response = worker.handle_msg("some_function_not_found", msg)
+ response = worker.execute_rpc("some_function_not_found", msg)
self.assertEqual(response, expected_response)
@@ -143,7 +167,7 @@ def test_handle_msg_server_exception(self):
"zero.protocols.zeromq.worker.async_to_sync",
side_effect=Exception("Exception occurred"),
):
- response = worker.handle_msg("some_function", msg)
+ response = worker.execute_rpc("some_function", msg)
self.assertEqual(response, expected_response)
@@ -219,3 +243,99 @@ def test_spawn_worker(self):
rpc_return_type_map,
)
mock_worker.start_dealer_worker.assert_called_once_with(worker_id)
+
+
+def some_function(msg: str) -> str:
+ return msg
+
+
+class TestWorkerHandleMsg(unittest.TestCase):
+ def setUp(self):
+ self.rpc_router = {
+ "get_rpc_contract": (MagicMock(), False),
+ "connect": (MagicMock(), False),
+ "some_function": (some_function, False),
+ }
+ self.device_comm_channel = "tcp://example.com:5555"
+ self.encoder = MagicMock(spec=Encoder)
+ self.rpc_input_type_map = {
+ "some_function": str,
+ }
+ self.rpc_return_type_map = {
+ "some_function": str,
+ }
+
+ def test_handle_msg_with_valid_input(self):
+ worker = _Worker(
+ self.rpc_router,
+ self.device_comm_channel,
+ self.encoder,
+ self.rpc_input_type_map,
+ self.rpc_return_type_map,
+ )
+ func_name_encoded = b"some_function"
+ data = self.encoder.encode("msg_data")
+
+ worker.execute_rpc = Mock()
+ worker.execute_rpc.return_value = "response"
+ self.encoder.decode_type.return_value = "msg_data"
+
+ response = worker.handle_msg(func_name_encoded, data)
+
+ worker.execute_rpc.assert_called_once_with(
+ func_name_encoded.decode(), "msg_data"
+ )
+ self.encoder.encode.assert_called_with("response")
+ self.assertEqual(response, self.encoder.encode.return_value)
+
+ def test_handle_msg_with_validation_error(self):
+ worker = _Worker(
+ self.rpc_router,
+ self.device_comm_channel,
+ self.encoder,
+ self.rpc_input_type_map,
+ self.rpc_return_type_map,
+ )
+ func_name_encoded = b"some_function"
+ data = b"msg_data"
+ expected_error = "__zerror__validation_error"
+ expected_error_message = "Validation Error"
+ expected_encoded_error = b"encoded_error"
+
+ self.encoder.decode_type.side_effect = msgspec.ValidationError(
+ expected_error_message
+ )
+ self.encoder.encode.return_value = expected_encoded_error
+
+ response = worker.handle_msg(func_name_encoded, data)
+
+ self.encoder.decode_type.assert_called_once_with(data, str)
+ self.encoder.encode.assert_called_once_with(
+ {expected_error: expected_error_message}
+ )
+ self.assertEqual(response, expected_encoded_error)
+ self.assertEqual(response, expected_encoded_error)
+
+ def test_handle_msg_with_server_exception(self):
+ worker = _Worker(
+ self.rpc_router,
+ self.device_comm_channel,
+ self.encoder,
+ self.rpc_input_type_map,
+ self.rpc_return_type_map,
+ )
+ func_name_encoded = b"some_function"
+ data = self.encoder.encode("msg_data")
+
+ worker.execute_rpc = Mock()
+ worker.execute_rpc.side_effect = Exception("Server Exception")
+ self.encoder.decode_type.return_value = "msg_data"
+
+ worker.handle_msg(func_name_encoded, data)
+
+ worker.execute_rpc.assert_called_once_with(
+ func_name_encoded.decode(), "msg_data"
+ )
+ self.encoder.encode.assert_called_with(
+ {"__zerror__server_exception": SERVER_PROCESSING_ERROR}
+ )
diff --git a/zero/codegen/codegen.py b/zero/codegen/codegen.py
index 9ffd6f4..b7449e0 100644
--- a/zero/codegen/codegen.py
+++ b/zero/codegen/codegen.py
@@ -1,86 +1,248 @@
+import datetime
+import decimal
+import enum
import inspect
-
-# from pydantic import BaseModel
+import sys
+import uuid
+from dataclasses import is_dataclass
+from typing import (
+ Callable,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ Union,
+ get_args,
+ get_origin,
+ get_type_hints,
+)
+
+import msgspec
+
+from zero.utils.type_util import typing_types
+
+python_version = sys.version_info
class CodeGen:
- def __init__(self, rpc_router, rpc_input_type_map, rpc_return_type_map):
+ def __init__(
+ self,
+ rpc_router: Dict[str, Tuple[Callable, bool]],
+ rpc_input_type_map: Dict[str, Optional[type]],
+ rpc_return_type_map: Dict[str, Optional[type]],
+ ):
self._rpc_router = rpc_router
self._rpc_input_type_map = rpc_input_type_map
self._rpc_return_type_map = rpc_return_type_map
- self._typing_imports = set()
+
+ # for imports
+ self._typing_imports: List[str] = [
+ str(typ).replace("typing.", "") for typ in typing_types
+ ]
+ self._typing_imports.sort()
+ self._datetime_imports: Set[str] = set()
+ self._has_uuid = False
+ self._has_decimal = False
+ self._has_enum = True
def generate_code(self, host="localhost", port=5559):
code = f"""# Generated by Zero
-# import types as per needed
+# import types as per needed, not all imports are shown here
from zero import ZeroClient
zero_client = ZeroClient("{host}", {port})
-
+"""
+ code += self.generate_models()
+ code += """
class RpcClient:
def __init__(self, zero_client: ZeroClient):
self._zero_client = zero_client
"""
for func_name in self._rpc_router:
+ input_param_name = (
+ None
+ if self._rpc_input_type_map[func_name] is None
+ else self.get_function_input_param_name(func_name)
+ )
code += f"""
{self.get_function_str(func_name)}
- return self._zero_client.call("{func_name}", {
- None if self._rpc_input_type_map[func_name] is None
- else self.get_function_input_param_name(func_name)
- })
+ return self._zero_client.call("{func_name}", {input_param_name})
"""
- # self.generate_data_classes() TODO: next feature
+
+ # add imports after first 2 lines
+ code_lines = code.split("\n")
+ code_lines.insert(2, self.get_imports(code))
+ code = "\n".join(code_lines)
+
+ if "typing." in code:
+ code = code.replace("typing.", "")
+ if "@dataclasses.dataclass" in code:
+ code = code.replace("@dataclasses.dataclass", "@dataclass")
+ if "datetime.datetime" in code:
+ code = code.replace("datetime.datetime", "datetime")
+ if "datetime.date" in code:
+ code = code.replace("datetime.date", "date")
+ if "datetime.time" in code:
+ code = code.replace("datetime.time", "time")
+
return code
- def get_imports(self):
- return f"from typing import {', '.join(i for i in self._typing_imports)}"
+ def get_imports(self, code):
+ for func_name in self._rpc_input_type_map:
+ input_type = self._rpc_input_type_map[func_name]
+ self._track_imports(input_type)
- def get_input_type_str(self, func_name: str): # pragma: no cover
- if self._rpc_input_type_map[func_name] is None:
- return ""
- if self._rpc_input_type_map[func_name].__module__ == "typing":
- type_name = self._rpc_input_type_map[func_name]._name
- self._typing_imports.add(type_name)
- return ": " + type_name
- return ": " + self._rpc_input_type_map[func_name].__name__
-
- def get_return_type_str(self, func_name: str): # pragma: no cover
- if self._rpc_return_type_map[func_name].__module__ == "typing":
- type_name = self._rpc_return_type_map[func_name]._name
- self._typing_imports.add(type_name)
- return type_name
- return self._rpc_return_type_map[func_name].__name__
+ for typ in list(self._typing_imports):
+ if typ + "[" not in code:
+ self._typing_imports.remove(typ)
+
+ import_lines = []
+
+ if "@dataclasses.dataclass" in code or "@dataclass" in code:
+ import_lines.append("from dataclasses import dataclass")
+
+ if self._datetime_imports:
+ import_lines.append(
+ "from datetime import " + ", ".join(sorted(self._datetime_imports))
+ )
+
+ if self._has_decimal:
+ import_lines.append("import decimal")
+ if self._has_enum:
+ import_lines.append("import enum")
+
+ if "(msgspec.Struct)" in code:
+ import_lines.append("import msgspec")
+
+ if "(Struct)" in code:
+ import_lines.append("from msgspec import Struct")
+
+ if self._typing_imports:
+ import_lines.append("from typing import " + ", ".join(self._typing_imports))
+
+ if self._has_uuid:
+ import_lines.append("import uuid")
+
+ return "\n".join(import_lines)
+
+ def _track_imports(self, input_type):
+ if not input_type:
+ return
+ if input_type in (datetime.datetime, datetime.date, datetime.time):
+ self._datetime_imports.add(input_type.__name__)
+ elif input_type == uuid.UUID:
+ self._has_uuid = True
+ elif input_type == decimal.Decimal:
+ self._has_decimal = True
def get_function_str(self, func_name: str):
func = self._rpc_router[func_name][0]
func_lines = inspect.getsourcelines(func)[0]
- def_line = [line for line in func_lines if "def" in line][0]
+ func_str = "".join(func_lines)
+ # from def to ->
+ def_str = func_str.split("def")[1].split("->")[0].strip()
+ def_str = "def " + def_str
- # put self after the first (
- def_line = def_line.replace(f"{func_name}(", f"{func_name}(self").replace(
- "async ", ""
- )
+ # Insert 'self' as the first parameter
+ insert_index = def_str.index("(") + 1
+ if self._rpc_input_type_map[func_name]: # If there is input, add 'self, '
+ def_str = def_str[:insert_index] + "self, " + def_str[insert_index:]
+ else: # If there is no input, just add 'self'
+ def_str = def_str[:insert_index] + "self" + def_str[insert_index:]
- # if there is input, add comma after self
- if self._rpc_input_type_map[func_name]:
- def_line = def_line.replace(f"{func_name}(self", f"{func_name}(self, ")
+ # from -> to :
+ return_type_str = func_str.split("->")[1].split(":")[0].strip()
+ # add return type
+ def_str = def_str + f" -> {return_type_str}:"
- return def_line.replace("\n", "")
+ return def_str.strip()
def get_function_input_param_name(self, func_name: str):
func = self._rpc_router[func_name][0]
func_lines = inspect.getsourcelines(func)[0]
- def_line = [line for line in func_lines if "def" in line][0]
- params = def_line.split("(")[1].split(")")[0]
- return params.split(":")[0].strip()
-
- # def generate_data_classes(self):
- # code = ""
- # for func_name in self._rpc_input_type_map:
- # input_class = self._rpc_input_type_map[func_name]
- # if input_class and is_pydantic(input_class):
- # code += inspect.getsource(input_class)
+ func_str = "".join(func_lines)
+ # from bracket to bracket
+ input_param_name = func_str.split("(")[1].split(")")[0]
+ # everything until :
+ input_param_name = input_param_name.split(":")[0]
+ return input_param_name.strip()
+
+ def _generate_class_code(self, cls: Type, already_generated: Set[Type]) -> str:
+ if cls in already_generated:
+ return ""
+
+ code = self._generate_code_for_bases(cls, already_generated)
+ code += self._generate_code_for_fields(cls, already_generated)
+
+ if python_version >= (3, 9):
+ code += inspect.getsource(cls) + "\n\n"
+ else:
+ # python 3.8 doesnt return @dataclass decorator
+ if is_dataclass(cls):
+ code += f"@dataclass\n{inspect.getsource(cls)}\n\n"
+ else:
+ code += inspect.getsource(cls) + "\n\n"
+
+ already_generated.add(cls)
+ return code
+
+ def _generate_code_for_bases(self, cls: Type, already_generated: Set[Type]) -> str:
+ code = ""
+ for base_cls in cls.__bases__:
+ if issubclass(base_cls, msgspec.Struct) and base_cls is not msgspec.Struct:
+ code += self._generate_class_code(base_cls, already_generated)
+ elif is_dataclass(base_cls):
+ code += self._generate_class_code(base_cls, already_generated)
+ return code
+
+ def _generate_code_for_fields(self, cls: Type, already_generated: Set[Type]) -> str:
+ code = ""
+ for field_type in get_type_hints(cls).values():
+ code += self._generate_code_for_type(field_type, already_generated)
+ return code
+
+ def _generate_code_for_type(self, typ: Type, already_generated: Set[Type]) -> str:
+ code = ""
+ typs = self._resolve_field_type(typ)
+ for it in typs:
+ self._track_imports(it)
+ if isinstance(it, type) and (
+ issubclass(it, (msgspec.Struct, enum.Enum, enum.IntEnum))
+ or is_dataclass(it)
+ ):
+ code += self._generate_class_code(it, already_generated)
+ return code
+
+ def _resolve_field_type(self, field_type) -> List[Type]:
+ origin = get_origin(field_type)
+ if origin in (list, tuple, set, frozenset, Optional):
+ return [get_args(field_type)[0]]
+ elif origin == dict:
+ return [get_args(field_type)[1]]
+ elif origin == Union:
+ return list(get_args(field_type))
+
+ return [field_type]
+
+ def generate_models(self) -> str:
+ already_generated: Set[Type] = set()
+ code = ""
+
+ merged_types = list(self._rpc_input_type_map.values()) + list(
+ self._rpc_return_type_map.values()
+ )
+ # retain order and remove duplicates
+ merged_types = list(dict.fromkeys(merged_types))
+
+ for input_type in merged_types:
+ if input_type is None:
+ continue
+ code += self._generate_code_for_type(input_type, already_generated)
+
+ return code
diff --git a/zero/config.py b/zero/config.py
index bd89c09..0ce5feb 100644
--- a/zero/config.py
+++ b/zero/config.py
@@ -11,7 +11,6 @@
RESERVED_FUNCTIONS = ["get_rpc_contract", "connect", "__server_info__"]
ZEROMQ_PATTERN = "proxy"
-ENCODER = "msgspec"
SUPPORTED_PROTOCOLS = {
"zeromq": {
"server": ZMQServer,
diff --git a/zero/encoder/__init__.py b/zero/encoder/__init__.py
index 6b48645..67da175 100644
--- a/zero/encoder/__init__.py
+++ b/zero/encoder/__init__.py
@@ -1,2 +1,3 @@
-from .factory import get_encoder
from .protocols import Encoder
+
+__all__ = ["Encoder"]
diff --git a/zero/encoder/factory.py b/zero/encoder/factory.py
deleted file mode 100644
index 1ed02b2..0000000
--- a/zero/encoder/factory.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from .msgspc import MsgspecEncoder
-from .protocols import Encoder
-
-
-def get_encoder(name: str) -> Encoder:
- if name == "msgspec":
- return MsgspecEncoder()
-
- raise ValueError(f"unknown encoder: {name}")
diff --git a/zero/protocols/zeromq/client.py b/zero/protocols/zeromq/client.py
index 1b80791..da185c0 100644
--- a/zero/protocols/zeromq/client.py
+++ b/zero/protocols/zeromq/client.py
@@ -1,9 +1,11 @@
import logging
import threading
-from typing import Dict, Optional, Type, TypeVar, Union
+from typing import Dict, Optional, Type, TypeVar
from zero import config
-from zero.encoder import Encoder, get_encoder
+from zero.encoder import Encoder
+from zero.encoder.msgspc import MsgspecEncoder
+from zero.utils.type_util import AllowedType
from zero.zeromq_patterns import (
AsyncZeroMQClient,
ZeroMQClient,
@@ -23,7 +25,7 @@ def __init__(
):
self._address = address
self._default_timeout = default_timeout
- self._encoder = encoder or get_encoder(config.ENCODER)
+ self._encoder = encoder or MsgspecEncoder()
self.client_pool = ZMQClientPool(
self._address,
@@ -34,7 +36,7 @@ def __init__(
def call(
self,
rpc_func_name: str,
- msg: Union[int, float, str, dict, list, tuple, None],
+ msg: AllowedType,
timeout: Optional[int] = None,
return_type: Optional[Type[T]] = None,
) -> T:
@@ -65,7 +67,7 @@ def __init__(
):
self._address = address
self._default_timeout = default_timeout
- self._encoder = encoder or get_encoder(config.ENCODER)
+ self._encoder = encoder or MsgspecEncoder()
self.client_pool = AsyncZMQClientPool(
self._address,
@@ -76,7 +78,7 @@ def __init__(
async def call(
self,
rpc_func_name: str,
- msg: Union[int, float, str, dict, list, tuple, None],
+ msg: AllowedType,
timeout: Optional[int] = None,
return_type: Optional[Type[T]] = None,
) -> T:
@@ -114,7 +116,7 @@ def __init__(
self._pool: Dict[int, ZeroMQClient] = {}
self._address = address
self._timeout = timeout
- self._encoder = encoder or get_encoder(config.ENCODER)
+ self._encoder = encoder or MsgspecEncoder()
def get(self) -> ZeroMQClient:
thread_id = threading.get_ident()
@@ -146,7 +148,7 @@ def __init__(
self._pool: Dict[int, AsyncZeroMQClient] = {}
self._address = address
self._timeout = timeout
- self._encoder = encoder or get_encoder(config.ENCODER)
+ self._encoder = encoder or MsgspecEncoder()
async def get(self) -> AsyncZeroMQClient:
thread_id = threading.get_ident()
diff --git a/zero/protocols/zeromq/worker.py b/zero/protocols/zeromq/worker.py
index 99b1ebe..7722667 100644
--- a/zero/protocols/zeromq/worker.py
+++ b/zero/protocols/zeromq/worker.py
@@ -1,7 +1,7 @@
import asyncio
import logging
import time
-from typing import Optional
+from typing import Any, Optional
from msgspec import ValidationError
@@ -37,43 +37,48 @@ def __init__(
)
def start_dealer_worker(self, worker_id):
- def process_message(func_name_encoded: bytes, data: bytes) -> Optional[bytes]:
- try:
- func_name = func_name_encoded.decode()
- input_type = self._rpc_input_type_map.get(func_name)
-
- msg = ""
- if data:
- if input_type:
- msg = self._encoder.decode_type(data, input_type)
- else:
- msg = self._encoder.decode(data)
-
- response = self.handle_msg(func_name, msg)
- return self._encoder.encode(response)
- except ValidationError as exc:
- logging.exception(exc)
- return self._encoder.encode({"__zerror__validation_error": str(exc)})
- except Exception as inner_exc: # pylint: disable=broad-except
- logging.exception(inner_exc)
- return self._encoder.encode(
- {"__zerror__server_exception": SERVER_PROCESSING_ERROR}
- )
-
worker = get_worker(config.ZEROMQ_PATTERN, worker_id)
try:
- worker.listen(self._device_comm_channel, process_message)
+ worker.listen(self._device_comm_channel, self.handle_msg)
+
except KeyboardInterrupt:
logging.warning(
"Caught KeyboardInterrupt, terminating worker %d", worker_id
)
+
except Exception as exc: # pylint: disable=broad-except
logging.exception(exc)
+
finally:
logging.warning("Closing worker %d", worker_id)
worker.close()
- def handle_msg(self, rpc, msg):
+ def handle_msg(self, func_name_encoded: bytes, data: bytes) -> Optional[bytes]:
+ try:
+ func_name = func_name_encoded.decode()
+ input_type = self._rpc_input_type_map.get(func_name)
+
+ msg = ""
+ if data:
+ if input_type:
+ msg = self._encoder.decode_type(data, input_type)
+ else:
+ msg = self._encoder.decode(data)
+
+ response = self.execute_rpc(func_name, msg)
+ return self._encoder.encode(response)
+
+ except ValidationError as exc:
+ logging.exception(exc)
+ return self._encoder.encode({"__zerror__validation_error": str(exc)})
+
+ except Exception as inner_exc: # pylint: disable=broad-except
+ logging.exception(inner_exc)
+ return self._encoder.encode(
+ {"__zerror__server_exception": SERVER_PROCESSING_ERROR}
+ )
+
+ def execute_rpc(self, rpc: str, msg: Any):
if rpc == "get_rpc_contract":
return self.generate_rpc_contract(msg)
@@ -88,10 +93,11 @@ def handle_msg(self, rpc, msg):
ret = None
try:
- if is_coro:
- ret = async_to_sync(func)(msg) if msg else async_to_sync(func)()
+ func_to_call = async_to_sync(func) if is_coro else func
+ if self._rpc_input_type_map.get(rpc):
+ ret = func_to_call(msg)
else:
- ret = func(msg) if msg else func()
+ ret = func_to_call()
except Exception as exc: # pylint: disable=broad-except
logging.exception(exc)
@@ -102,6 +108,7 @@ def handle_msg(self, rpc, msg):
def generate_rpc_contract(self, msg):
try:
return self.codegen.generate_code(msg[0], msg[1])
+
except Exception as exc: # pylint: disable=broad-except
logging.exception(exc)
return {"__zerror__failed_to_generate_client_code": str(exc)}
diff --git a/zero/rpc/client.py b/zero/rpc/client.py
index 77d54e8..1595aba 100644
--- a/zero/rpc/client.py
+++ b/zero/rpc/client.py
@@ -1,10 +1,12 @@
-from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Optional, Type, TypeVar
from zero import config
-from zero.encoder import Encoder, get_encoder
+from zero.encoder import Encoder
+from zero.encoder.msgspc import MsgspecEncoder
from zero.error import MethodNotFoundException, RemoteException, ValidationException
+from zero.utils.type_util import AllowedType
-if TYPE_CHECKING:
+if TYPE_CHECKING: # pragma: no cover
from zero.rpc.protocols import AsyncZeroClientProtocol, ZeroClientProtocol
T = TypeVar("T")
@@ -55,7 +57,7 @@ def __init__(
"""
self._address = f"tcp://{host}:{port}"
self._default_timeout = default_timeout
- self._encoder = encoder or get_encoder(config.ENCODER)
+ self._encoder = encoder or MsgspecEncoder()
self._client_inst: "ZeroClientProtocol" = self._determine_client_cls(protocol)(
self._address,
self._default_timeout,
@@ -79,7 +81,7 @@ def _determine_client_cls(self, protocol: str) -> Type["ZeroClientProtocol"]:
def call(
self,
rpc_func_name: str,
- msg: Union[int, float, str, dict, list, tuple, None],
+ msg: AllowedType,
timeout: Optional[int] = None,
return_type: Optional[Type[T]] = None,
) -> T:
@@ -173,7 +175,7 @@ def __init__(
"""
self._address = f"tcp://{host}:{port}"
self._default_timeout = default_timeout
- self._encoder = encoder or get_encoder(config.ENCODER)
+ self._encoder = encoder or MsgspecEncoder()
self._client_inst: "AsyncZeroClientProtocol" = self._determine_client_cls(
"zeromq"
)(
@@ -199,7 +201,7 @@ def _determine_client_cls(self, protocol: str) -> Type["AsyncZeroClientProtocol"
async def call(
self,
rpc_func_name: str,
- msg: Union[int, float, str, dict, list, tuple, None],
+ msg: AllowedType,
timeout: Optional[int] = None,
return_type: Optional[Type[T]] = None,
) -> Optional[T]:
diff --git a/zero/rpc/protocols.py b/zero/rpc/protocols.py
index 3752229..a4f8f5f 100644
--- a/zero/rpc/protocols.py
+++ b/zero/rpc/protocols.py
@@ -6,17 +6,17 @@
Tuple,
Type,
TypeVar,
- Union,
runtime_checkable,
)
from zero.encoder import Encoder
+from zero.utils.type_util import AllowedType
T = TypeVar("T")
@runtime_checkable
-class ZeroServerProtocol(Protocol): # pragma: no cover
+class ZeroServerProtocol(Protocol):
def __init__(
self,
address: str,
@@ -35,7 +35,7 @@ def stop(self):
@runtime_checkable
-class ZeroClientProtocol(Protocol): # pragma: no cover
+class ZeroClientProtocol(Protocol):
def __init__(
self,
address: str,
@@ -47,7 +47,7 @@ def __init__(
def call(
self,
rpc_func_name: str,
- msg: Union[int, float, str, dict, list, tuple, None],
+ msg: AllowedType,
timeout: Optional[int] = None,
return_type: Optional[Type[T]] = None,
) -> Optional[T]:
@@ -58,7 +58,7 @@ def close(self):
@runtime_checkable
-class AsyncZeroClientProtocol(Protocol): # pragma: no cover
+class AsyncZeroClientProtocol(Protocol):
def __init__(
self,
address: str,
@@ -70,7 +70,7 @@ def __init__(
async def call(
self,
rpc_func_name: str,
- msg: Union[int, float, str, dict, list, tuple, None],
+ msg: AllowedType,
timeout: Optional[int] = None,
return_type: Optional[Type[T]] = None,
) -> Optional[T]:
diff --git a/zero/rpc/server.py b/zero/rpc/server.py
index f9cbc14..4f3eba6 100644
--- a/zero/rpc/server.py
+++ b/zero/rpc/server.py
@@ -1,13 +1,24 @@
import logging
import os
from asyncio import iscoroutinefunction
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Coroutine,
+ Dict,
+ Optional,
+ Tuple,
+ Type,
+ Union,
+)
from zero import config
-from zero.encoder import Encoder, get_encoder
+from zero.encoder import Encoder
+from zero.encoder.msgspc import MsgspecEncoder
from zero.utils import type_util
-if TYPE_CHECKING:
+if TYPE_CHECKING: # pragma: no cover
from .protocols import ZeroServerProtocol
# import uvloop
@@ -47,7 +58,11 @@ def __init__(
self._address = f"tcp://{self._host}:{self._port}"
# to encode/decode messages from/to client
- self._encoder = encoder or get_encoder(config.ENCODER)
+ if encoder and not isinstance(encoder, Encoder):
+ raise TypeError(
+ f"encoder should be an instance of Encoder; not {type(encoder)}"
+ )
+ self._encoder = encoder or MsgspecEncoder()
# Stores rpc functions against their names
# and if they are coroutines
@@ -79,7 +94,7 @@ def _determine_server_cls(self, protocol: str) -> Type["ZeroServerProtocol"]:
)
return server_cls
- def register_rpc(self, func: Callable):
+ def register_rpc(self, func: Callable[..., Union[Any, Coroutine]]):
"""
Register a function available for clients.
Function should have a single argument.
diff --git a/zero/utils/type_util.py b/zero/utils/type_util.py
index fd69d4c..9e407f8 100644
--- a/zero/utils/type_util.py
+++ b/zero/utils/type_util.py
@@ -4,7 +4,18 @@
import enum
import typing
import uuid
-from typing import Callable, Optional, get_origin, get_type_hints
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ Dict,
+ Optional,
+ Protocol,
+ Type,
+ Union,
+ get_origin,
+ get_type_hints,
+)
import msgspec
@@ -17,15 +28,10 @@
bytes,
bytearray,
tuple,
- typing.Tuple,
list,
- typing.List,
dict,
- typing.Dict,
set,
- typing.Set,
frozenset,
- typing.FrozenSet,
]
std_lib_types: typing.List = [
@@ -40,20 +46,61 @@
]
typing_types: typing.List = [
- typing.Any,
+ typing.Tuple,
+ typing.List,
+ typing.Dict,
+ typing.Set,
+ typing.FrozenSet,
typing.Union,
typing.Optional,
]
msgspec_types: typing.List = [
msgspec.Struct,
- msgspec.Raw,
]
allowed_types = builtin_types + std_lib_types + typing_types
+class IsDataclass(Protocol):
+ # as already noted in comments, checking for this attribute is currently
+ # the most reliable way to ascertain that something is a dataclass
+ __dataclass_fields__: ClassVar[Dict[str, Any]]
+
+
+AllowedType = Union[
+ None,
+ bool,
+ int,
+ float,
+ str,
+ bytes,
+ bytearray,
+ tuple,
+ list,
+ dict,
+ set,
+ frozenset,
+ datetime.datetime,
+ datetime.date,
+ datetime.time,
+ uuid.UUID,
+ decimal.Decimal,
+ enum.Enum,
+ enum.IntEnum,
+ IsDataclass,
+ typing.Tuple,
+ typing.List,
+ typing.Dict,
+ typing.Set,
+ typing.FrozenSet,
+ msgspec.Struct,
+ Type[enum.Enum], # For enum classes
+ Type[enum.IntEnum], # For int enum classes
+]
+
+
def verify_function_args(func: Callable) -> None:
arg_count = func.__code__.co_argcount
if arg_count < 1:
@@ -73,13 +120,6 @@ def verify_function_args(func: Callable) -> None:
def verify_function_return(func: Callable) -> None:
- return_count = func.__code__.co_argcount
- if return_count > 1:
- raise ValueError(
- f"`{func.__name__}` has more than 1 return values; "
- "RPC functions can have only one return value"
- )
-
types = get_type_hints(func)
if not types.get("return"):
raise TypeError(
@@ -106,17 +146,12 @@ def get_function_return_class(func: Callable):
def verify_function_input_type(func: Callable):
input_type = get_function_input_class(func)
- if input_type in allowed_types:
+ if input_type is None:
return
- origin_type = get_origin(input_type)
- if origin_type is not None and origin_type in allowed_types:
+ if is_allowed_type(input_type):
return
- for mtype in msgspec_types:
- if input_type is not None and issubclass(input_type, mtype):
- return
-
raise TypeError(
f"{func.__name__} has type {input_type} which is not allowed; "
"allowed types are: \n" + "\n".join([str(t) for t in allowed_types])
@@ -125,16 +160,21 @@ def verify_function_input_type(func: Callable):
def verify_function_return_type(func: Callable):
return_type = get_function_return_class(func)
- if return_type in allowed_types:
- return
- origin_type = get_origin(return_type)
- if origin_type is not None and origin_type in allowed_types:
- return
+ # None is not allowed as return type
+ if return_type is None:
+ raise TypeError(
+ f"{func.__name__} returns None; RPC functions must return a value"
+ )
- for typ in msgspec_types:
- if issubclass(return_type, typ):
- return
+ # Optional is not allowed as return type
+ if get_origin(return_type) == typing.Union and type(None) in return_type.__args__:
+ raise TypeError(
+ f"{func.__name__} returns Optional; RPC functions must return a value"
+ )
+
+ if is_allowed_type(return_type):
+ return
raise TypeError(
f"{func.__name__} has return type {return_type} which is not allowed; "
@@ -151,22 +191,22 @@ def verify_allowed_type(msg, rpc_method: Optional[str] = None):
)
-def verify_incoming_rpc_call_input_type(
- msg, rpc_method: str, rpc_input_type_map: dict
-): # pragma: no cover
- input_type = rpc_input_type_map[rpc_method]
- if input_type is None:
- return
+def is_allowed_type(typ: Type):
+ if typ in allowed_types:
+ return True
+
+ if str(typ).startswith("