Skip to content

Commit

Permalink
Improve client using asyncio event
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananto30 committed Jun 28, 2024
1 parent 4773c90 commit 04e5369
Show file tree
Hide file tree
Showing 17 changed files with 241 additions and 301 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ sanic | 18793.08 | 5.88 | 12739.37 | 8.
zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69
zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80

Seems like blacksheep is the aster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆
Seems like blacksheep is faster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆

# Roadmap 🗺

Expand Down
2 changes: 1 addition & 1 deletion tests/concurrency/rps_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
async def task(semaphore, items):
async with semaphore:
try:
await async_client.call("sum_async", items)
await async_client.call("sum_sync", items)
# res = await async_client.call("sum_async", items)
# print(res)
except Exception as e:
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from zero import ZeroServer
from zero.encoder.protocols import Encoder
from zero.zeromq_patterns.protocols import ZeroMQBroker
from zero.zeromq_patterns.interfaces import ZeroMQBroker

DEFAULT_PORT = 5559
DEFAULT_HOST = "0.0.0.0"
Expand Down Expand Up @@ -216,6 +216,17 @@ class Message:
def add(msg: Message) -> Message:
return Message()

def test_register_rpc_with_long_name(self):
server = ZeroServer()

with self.assertRaises(ValueError):

@server.register_rpc
def add_this_is_a_very_long_name_for_a_function_more_than_120_characters_ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff(
msg: Tuple[int, int]
) -> int:
return msg[0] + msg[1]

def test_server_run(self):
server = ZeroServer()

Expand Down
130 changes: 22 additions & 108 deletions zero/protocols/zeromq/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import asyncio
import logging
import threading
from typing import Any, Dict, Optional, Type, TypeVar, Union
from typing import Dict, Optional, Type, TypeVar, Union

from zero import config
from zero.encoder import Encoder, get_encoder
from zero.error import TimeoutException
from zero.utils import util
from zero.zeromq_patterns import (
AsyncZeroMQClient,
ZeroMQClient,
Expand All @@ -28,7 +25,7 @@ def __init__(
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)

self.client_pool = ZeroMQClientPool(
self.client_pool = ZMQClientPool(
self._address,
self._default_timeout,
self._encoder,
Expand All @@ -43,47 +40,17 @@ def call(
) -> T:
zmqc = self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout

def _poll_data():
# TODO poll is slow, need to find a better way
if not zmqc.poll(_timeout):
raise TimeoutException(
f"Timeout while sending message at {self._address}"
)

rcv_data = zmqc.recv()

# first 32 bytes as response id
resp_id = rcv_data[:32].decode()

# the rest is response data
resp_data_encoded = rcv_data[32:]
resp_data = (
self._encoder.decode(resp_data_encoded)
if return_type is None
else self._encoder.decode_type(resp_data_encoded, return_type)
)

return resp_id, resp_data

req_id = util.unique_id()

# function name exactly 120 bytes
func_name_bytes = rpc_func_name.ljust(120).encode()

# make function name exactly 80 bytes
func_name_bytes = rpc_func_name.ljust(80).encode()
msg_bytes = b"" if msg is None else self._encoder.encode(msg)
zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)

resp_id, resp_data = None, None
# as the client is synchronous, we know that the response will be available any next poll
# we try to get the response until timeout because a previous call might be timed out
# and the response is still in the socket,
# so we poll until we get the response for this call
while resp_id != req_id:
resp_id, resp_data = _poll_data()
resp_data_bytes = zmqc.request(func_name_bytes + msg_bytes, timeout)

return resp_data # type: ignore
return (
self._encoder.decode(resp_data_bytes)
if return_type is None
else self._encoder.decode_type(resp_data_bytes, return_type)
)

def close(self):
self.client_pool.close()
Expand All @@ -99,9 +66,8 @@ def __init__(
self._address = address
self._default_timeout = default_timeout
self._encoder = encoder or get_encoder(config.ENCODER)
self._resp_map: Dict[str, Any] = {}

self.client_pool = AsyncZeroMQClientPool(
self.client_pool = AsyncZMQClientPool(
self._address,
self._default_timeout,
self._encoder,
Expand All @@ -116,63 +82,23 @@ async def call(
) -> T:
zmqc = await self.client_pool.get()

_timeout = self._default_timeout if timeout is None else timeout
expire_at = util.current_time_us() + (_timeout * 1000)

async def _poll_data():
# TODO async has issue with poller, after 3-4 calls, it returns empty
# if not await zmqc.poll(_timeout):
# raise TimeoutException(f"Timeout while sending message at {self._address}")

# first 32 bytes as response id
resp = await zmqc.recv()
resp_id = resp[:32].decode()

# the rest is response data
resp_data_encoded = resp[32:]
resp_data = (
self._encoder.decode(resp_data_encoded)
if return_type is None
else self._encoder.decode_type(resp_data_encoded, return_type)
)
self._resp_map[resp_id] = resp_data

# TODO try to use pipe instead of sleep
# await self.peer1.send(b"")

req_id = util.unique_id()

# function name exactly 120 bytes
func_name_bytes = rpc_func_name.ljust(120).encode()

# make function name exactly 80 bytes
func_name_bytes = rpc_func_name.ljust(80).encode()
msg_bytes = b"" if msg is None else self._encoder.encode(msg)
await zmqc.send(req_id.encode() + func_name_bytes + msg_bytes)

# every request poll the data, so whenever a response comes, it will be stored in __resps
# dont need to poll again in the while loop
await _poll_data()

while req_id not in self._resp_map and util.current_time_us() <= expire_at:
# TODO the problem with the zpipe is that we can miss some response
# when we come to this line
# await self.peer2.recv()
await asyncio.sleep(1e-6)
resp_data_bytes = await zmqc.request(func_name_bytes + msg_bytes, timeout)

if util.current_time_us() > expire_at:
raise TimeoutException(
f"Timeout while waiting for response at {self._address}"
)

resp_data = self._resp_map.pop(req_id)

return resp_data
return (
self._encoder.decode(resp_data_bytes)
if return_type is None
else self._encoder.decode_type(resp_data_bytes, return_type)
)

def close(self):
self.client_pool.close()
self._resp_map = {}


class ZeroMQClientPool:
class ZMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
Expand All @@ -196,21 +122,15 @@ def get(self) -> ZeroMQClient:
logging.debug("No connection found in current thread, creating new one")
self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout)
self._pool[thread_id].connect(self._address)
self._try_connect_ping(self._pool[thread_id])
return self._pool[thread_id]

def _try_connect_ping(self, client: ZeroMQClient):
client.send(util.unique_id().encode() + b"connect" + b"")
client.recv()
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
self._pool = {}


class AsyncZeroMQClientPool:
class AsyncZMQClientPool:
"""
Connections are based on different threads and processes.
Each time a call is made it tries to get the connection from the pool,
Expand All @@ -235,15 +155,9 @@ async def get(self) -> AsyncZeroMQClient:
self._pool[thread_id] = get_async_client(
config.ZEROMQ_PATTERN, self._timeout
)
self._pool[thread_id].connect(self._address)
await self._try_connect_ping(self._pool[thread_id])
await self._pool[thread_id].connect(self._address)
return self._pool[thread_id]

async def _try_connect_ping(self, client: AsyncZeroMQClient):
await client.send(util.unique_id().encode() + b"connect" + b"")
await client.recv()
logging.info("Connected to server at %s", self._address)

def close(self):
for client in self._pool.values():
client.close()
Expand Down
7 changes: 3 additions & 4 deletions zero/protocols/zeromq/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,14 @@ def start(self, workers: int = os.cpu_count() or 1):

self._start_server(workers, spawn_worker)

def _start_server(self, workers: int, spawn_worker: Callable):
def _start_server(self, workers: int, spawn_worker: Callable[[int], None]):
self._pool = Pool(workers)

# process termination signals
util.register_signal_term(self._sig_handler)

# TODO: by default we start the workers with processes,
# but we need support to run only router, without workers
self._pool.map_async(spawn_worker, list(range(1, workers + 1)))
worker_ids = list(range(1, workers + 1))
self._pool.map_async(spawn_worker, worker_ids)

# blocking
with zmq.utils.win32.allow_interrupt(self.stop):
Expand Down
4 changes: 1 addition & 3 deletions zero/protocols/zeromq/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def process_message(func_name_encoded: bytes, data: bytes) -> Optional[bytes]:
except ValidationError as exc:
logging.exception(exc)
return self._encoder.encode({"__zerror__validation_error": str(exc)})
except (
Exception
) as inner_exc: # pragma: no cover pylint: disable=broad-except
except Exception as inner_exc: # pylint: disable=broad-except
logging.exception(inner_exc)
return self._encoder.encode(
{"__zerror__server_exception": SERVER_PROCESSING_ERROR}
Expand Down
6 changes: 4 additions & 2 deletions zero/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def register_rpc(self, func: Callable):
Function should have a single argument.
Argument and return should have a type hint.
If the function got exception, client will get None as return value.
Parameters
----------
func: Callable
Expand Down Expand Up @@ -135,6 +133,10 @@ def run(self, workers: int = os.cpu_count() or 1):
def _verify_function_name(self, func):
if not isinstance(func, Callable):
raise ValueError(f"register function; not {type(func)}")
if len(func.__name__) > 80:
raise ValueError(
"function name can be at max 80" f" characters; {func.__name__}"
)
if func.__name__ in self._rpc_router:
raise ValueError(
f"cannot have two RPC function same name: `{func.__name__}`"
Expand Down
22 changes: 11 additions & 11 deletions zero/utils/async_to_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@
import threading
from functools import wraps

_loop = None
_thrd = None
_LOOP = None
_THRD = None


def start_async_loop():
global _loop, _thrd
if _loop is None or _thrd is None or not _thrd.is_alive():
_loop = asyncio.new_event_loop()
_thrd = threading.Thread(
target=_loop.run_forever, name="Async Runner", daemon=True
global _LOOP, _THRD # pylint: disable=global-statement
if _LOOP is None or _THRD is None or not _THRD.is_alive():
_LOOP = asyncio.new_event_loop()
_THRD = threading.Thread(
target=_LOOP.run_forever, name="Async Runner", daemon=True
)
_thrd.start()
_THRD.start()


def async_to_sync(func):
@wraps(func)
def run(*args, **kwargs):
start_async_loop() # Ensure the loop and thread are started
try:
future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _loop)
future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _LOOP)
return future.result()
except Exception as e:
print(f"Exception occurred: {e}")
except Exception as exc:
print(f"Exception occurred: {exc}")
raise

return run
4 changes: 2 additions & 2 deletions zero/utils/type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ def verify_function_return_type(func: Callable):
if origin_type is not None and origin_type in allowed_types:
return

for t in msgspec_types:
if issubclass(return_type, t):
for typ in msgspec_types:
if issubclass(return_type, typ):
return

raise TypeError(
Expand Down
2 changes: 1 addition & 1 deletion zero/zeromq_patterns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .factory import get_async_client, get_broker, get_client, get_worker
from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
2 changes: 1 addition & 1 deletion zero/zeromq_patterns/factory.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from zero.zeromq_patterns import queue_device

from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker
from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker


def get_client(pattern: str, default_timeout: int = 2000) -> ZeroMQClient:
Expand Down
6 changes: 3 additions & 3 deletions zero/zeromq_patterns/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def zpipe_async(
ctx: zmq.asyncio.Context, timeout: int = 1000
ctx: zmq.asyncio.Context,
) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: # pragma: no cover
"""
Build inproc pipe for talking to threads
Expand All @@ -20,8 +20,8 @@ def zpipe_async(
sock_b = ctx.socket(zmq.PAIR)
sock_a.linger = sock_b.linger = 0
sock_a.hwm = sock_b.hwm = 1
sock_a.sndtimeo = sock_b.sndtimeo = timeout
sock_a.rcvtimeo = sock_b.rcvtimeo = timeout
sock_a.sndtimeo = sock_b.sndtimeo = 0
sock_a.rcvtimeo = sock_b.rcvtimeo = 0
iface = f"inproc://{util.unique_id()}"
sock_a.bind(iface)
sock_b.connect(iface)
Expand Down
Loading

0 comments on commit 04e5369

Please sign in to comment.