From 04e53690d8aef4758aaca4170cfe10e7901d2ce5 Mon Sep 17 00:00:00 2001 From: Azizul Haque Ananto Date: Tue, 25 Jun 2024 20:43:51 +0200 Subject: [PATCH] Improve client using asyncio event --- README.md | 2 +- tests/concurrency/rps_async.py | 2 +- tests/unit/test_server.py | 13 +- zero/protocols/zeromq/client.py | 130 +++----------- zero/protocols/zeromq/server.py | 7 +- zero/protocols/zeromq/worker.py | 4 +- zero/rpc/server.py | 6 +- zero/utils/async_to_sync.py | 22 +-- zero/utils/type_util.py | 4 +- zero/zeromq_patterns/__init__.py | 2 +- zero/zeromq_patterns/factory.py | 2 +- zero/zeromq_patterns/helpers.py | 6 +- zero/zeromq_patterns/interfaces.py | 59 +++++++ zero/zeromq_patterns/protocols.py | 86 --------- zero/zeromq_patterns/queue_device/broker.py | 2 +- zero/zeromq_patterns/queue_device/client.py | 184 ++++++++++++-------- zero/zeromq_patterns/queue_device/worker.py | 11 +- 17 files changed, 241 insertions(+), 301 deletions(-) create mode 100644 zero/zeromq_patterns/interfaces.py delete mode 100644 zero/zeromq_patterns/protocols.py diff --git a/README.md b/README.md index 599af3a..1754f17 100644 --- a/README.md +++ b/README.md @@ -257,7 +257,7 @@ sanic | 18793.08 | 5.88 | 12739.37 | 8. zero(sync) | 28471.47 | 4.12 | 18114.84 | 6.69 zero(async) | 29012.03 | 3.43 | 20956.48 | 5.80 -Seems like blacksheep is the aster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆 +Seems like blacksheep is faster on hello world, but in more complex operations like saving to redis, zero is the winner! 🏆 # Roadmap 🗺 diff --git a/tests/concurrency/rps_async.py b/tests/concurrency/rps_async.py index b8ad5e1..bd4f1a0 100644 --- a/tests/concurrency/rps_async.py +++ b/tests/concurrency/rps_async.py @@ -11,7 +11,7 @@ async def task(semaphore, items): async with semaphore: try: - await async_client.call("sum_async", items) + await async_client.call("sum_sync", items) # res = await async_client.call("sum_async", items) # print(res) except Exception as e: diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index f09fd9e..7790820 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -8,7 +8,7 @@ from zero import ZeroServer from zero.encoder.protocols import Encoder -from zero.zeromq_patterns.protocols import ZeroMQBroker +from zero.zeromq_patterns.interfaces import ZeroMQBroker DEFAULT_PORT = 5559 DEFAULT_HOST = "0.0.0.0" @@ -216,6 +216,17 @@ class Message: def add(msg: Message) -> Message: return Message() + def test_register_rpc_with_long_name(self): + server = ZeroServer() + + with self.assertRaises(ValueError): + + @server.register_rpc + def add_this_is_a_very_long_name_for_a_function_more_than_120_characters_ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff( + msg: Tuple[int, int] + ) -> int: + return msg[0] + msg[1] + def test_server_run(self): server = ZeroServer() diff --git a/zero/protocols/zeromq/client.py b/zero/protocols/zeromq/client.py index a78598d..1b80791 100644 --- a/zero/protocols/zeromq/client.py +++ b/zero/protocols/zeromq/client.py @@ -1,12 +1,9 @@ -import asyncio import logging import threading -from typing import Any, Dict, Optional, Type, TypeVar, Union +from typing import Dict, Optional, Type, TypeVar, Union from zero import config from zero.encoder import Encoder, get_encoder -from zero.error import TimeoutException -from zero.utils import util from zero.zeromq_patterns import ( AsyncZeroMQClient, ZeroMQClient, @@ -28,7 +25,7 @@ def __init__( self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) - self.client_pool = ZeroMQClientPool( + self.client_pool = ZMQClientPool( self._address, self._default_timeout, self._encoder, @@ -43,47 +40,17 @@ def call( ) -> T: zmqc = self.client_pool.get() - _timeout = self._default_timeout if timeout is None else timeout - - def _poll_data(): - # TODO poll is slow, need to find a better way - if not zmqc.poll(_timeout): - raise TimeoutException( - f"Timeout while sending message at {self._address}" - ) - - rcv_data = zmqc.recv() - - # first 32 bytes as response id - resp_id = rcv_data[:32].decode() - - # the rest is response data - resp_data_encoded = rcv_data[32:] - resp_data = ( - self._encoder.decode(resp_data_encoded) - if return_type is None - else self._encoder.decode_type(resp_data_encoded, return_type) - ) - - return resp_id, resp_data - - req_id = util.unique_id() - - # function name exactly 120 bytes - func_name_bytes = rpc_func_name.ljust(120).encode() - + # make function name exactly 80 bytes + func_name_bytes = rpc_func_name.ljust(80).encode() msg_bytes = b"" if msg is None else self._encoder.encode(msg) - zmqc.send(req_id.encode() + func_name_bytes + msg_bytes) - resp_id, resp_data = None, None - # as the client is synchronous, we know that the response will be available any next poll - # we try to get the response until timeout because a previous call might be timed out - # and the response is still in the socket, - # so we poll until we get the response for this call - while resp_id != req_id: - resp_id, resp_data = _poll_data() + resp_data_bytes = zmqc.request(func_name_bytes + msg_bytes, timeout) - return resp_data # type: ignore + return ( + self._encoder.decode(resp_data_bytes) + if return_type is None + else self._encoder.decode_type(resp_data_bytes, return_type) + ) def close(self): self.client_pool.close() @@ -99,9 +66,8 @@ def __init__( self._address = address self._default_timeout = default_timeout self._encoder = encoder or get_encoder(config.ENCODER) - self._resp_map: Dict[str, Any] = {} - self.client_pool = AsyncZeroMQClientPool( + self.client_pool = AsyncZMQClientPool( self._address, self._default_timeout, self._encoder, @@ -116,63 +82,23 @@ async def call( ) -> T: zmqc = await self.client_pool.get() - _timeout = self._default_timeout if timeout is None else timeout - expire_at = util.current_time_us() + (_timeout * 1000) - - async def _poll_data(): - # TODO async has issue with poller, after 3-4 calls, it returns empty - # if not await zmqc.poll(_timeout): - # raise TimeoutException(f"Timeout while sending message at {self._address}") - - # first 32 bytes as response id - resp = await zmqc.recv() - resp_id = resp[:32].decode() - - # the rest is response data - resp_data_encoded = resp[32:] - resp_data = ( - self._encoder.decode(resp_data_encoded) - if return_type is None - else self._encoder.decode_type(resp_data_encoded, return_type) - ) - self._resp_map[resp_id] = resp_data - - # TODO try to use pipe instead of sleep - # await self.peer1.send(b"") - - req_id = util.unique_id() - - # function name exactly 120 bytes - func_name_bytes = rpc_func_name.ljust(120).encode() - + # make function name exactly 80 bytes + func_name_bytes = rpc_func_name.ljust(80).encode() msg_bytes = b"" if msg is None else self._encoder.encode(msg) - await zmqc.send(req_id.encode() + func_name_bytes + msg_bytes) - - # every request poll the data, so whenever a response comes, it will be stored in __resps - # dont need to poll again in the while loop - await _poll_data() - while req_id not in self._resp_map and util.current_time_us() <= expire_at: - # TODO the problem with the zpipe is that we can miss some response - # when we come to this line - # await self.peer2.recv() - await asyncio.sleep(1e-6) + resp_data_bytes = await zmqc.request(func_name_bytes + msg_bytes, timeout) - if util.current_time_us() > expire_at: - raise TimeoutException( - f"Timeout while waiting for response at {self._address}" - ) - - resp_data = self._resp_map.pop(req_id) - - return resp_data + return ( + self._encoder.decode(resp_data_bytes) + if return_type is None + else self._encoder.decode_type(resp_data_bytes, return_type) + ) def close(self): self.client_pool.close() - self._resp_map = {} -class ZeroMQClientPool: +class ZMQClientPool: """ Connections are based on different threads and processes. Each time a call is made it tries to get the connection from the pool, @@ -196,21 +122,15 @@ def get(self) -> ZeroMQClient: logging.debug("No connection found in current thread, creating new one") self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout) self._pool[thread_id].connect(self._address) - self._try_connect_ping(self._pool[thread_id]) return self._pool[thread_id] - def _try_connect_ping(self, client: ZeroMQClient): - client.send(util.unique_id().encode() + b"connect" + b"") - client.recv() - logging.info("Connected to server at %s", self._address) - def close(self): for client in self._pool.values(): client.close() self._pool = {} -class AsyncZeroMQClientPool: +class AsyncZMQClientPool: """ Connections are based on different threads and processes. Each time a call is made it tries to get the connection from the pool, @@ -235,15 +155,9 @@ async def get(self) -> AsyncZeroMQClient: self._pool[thread_id] = get_async_client( config.ZEROMQ_PATTERN, self._timeout ) - self._pool[thread_id].connect(self._address) - await self._try_connect_ping(self._pool[thread_id]) + await self._pool[thread_id].connect(self._address) return self._pool[thread_id] - async def _try_connect_ping(self, client: AsyncZeroMQClient): - await client.send(util.unique_id().encode() + b"connect" + b"") - await client.recv() - logging.info("Connected to server at %s", self._address) - def close(self): for client in self._pool.values(): client.close() diff --git a/zero/protocols/zeromq/server.py b/zero/protocols/zeromq/server.py index 766a506..a47dd22 100644 --- a/zero/protocols/zeromq/server.py +++ b/zero/protocols/zeromq/server.py @@ -66,15 +66,14 @@ def start(self, workers: int = os.cpu_count() or 1): self._start_server(workers, spawn_worker) - def _start_server(self, workers: int, spawn_worker: Callable): + def _start_server(self, workers: int, spawn_worker: Callable[[int], None]): self._pool = Pool(workers) # process termination signals util.register_signal_term(self._sig_handler) - # TODO: by default we start the workers with processes, - # but we need support to run only router, without workers - self._pool.map_async(spawn_worker, list(range(1, workers + 1))) + worker_ids = list(range(1, workers + 1)) + self._pool.map_async(spawn_worker, worker_ids) # blocking with zmq.utils.win32.allow_interrupt(self.stop): diff --git a/zero/protocols/zeromq/worker.py b/zero/protocols/zeromq/worker.py index 4fb6139..99b1ebe 100644 --- a/zero/protocols/zeromq/worker.py +++ b/zero/protocols/zeromq/worker.py @@ -54,9 +54,7 @@ def process_message(func_name_encoded: bytes, data: bytes) -> Optional[bytes]: except ValidationError as exc: logging.exception(exc) return self._encoder.encode({"__zerror__validation_error": str(exc)}) - except ( - Exception - ) as inner_exc: # pragma: no cover pylint: disable=broad-except + except Exception as inner_exc: # pylint: disable=broad-except logging.exception(inner_exc) return self._encoder.encode( {"__zerror__server_exception": SERVER_PROCESSING_ERROR} diff --git a/zero/rpc/server.py b/zero/rpc/server.py index e05f3a1..f9cbc14 100644 --- a/zero/rpc/server.py +++ b/zero/rpc/server.py @@ -85,8 +85,6 @@ def register_rpc(self, func: Callable): Function should have a single argument. Argument and return should have a type hint. - If the function got exception, client will get None as return value. - Parameters ---------- func: Callable @@ -135,6 +133,10 @@ def run(self, workers: int = os.cpu_count() or 1): def _verify_function_name(self, func): if not isinstance(func, Callable): raise ValueError(f"register function; not {type(func)}") + if len(func.__name__) > 80: + raise ValueError( + "function name can be at max 80" f" characters; {func.__name__}" + ) if func.__name__ in self._rpc_router: raise ValueError( f"cannot have two RPC function same name: `{func.__name__}`" diff --git a/zero/utils/async_to_sync.py b/zero/utils/async_to_sync.py index 770986e..aac8fb5 100644 --- a/zero/utils/async_to_sync.py +++ b/zero/utils/async_to_sync.py @@ -2,18 +2,18 @@ import threading from functools import wraps -_loop = None -_thrd = None +_LOOP = None +_THRD = None def start_async_loop(): - global _loop, _thrd - if _loop is None or _thrd is None or not _thrd.is_alive(): - _loop = asyncio.new_event_loop() - _thrd = threading.Thread( - target=_loop.run_forever, name="Async Runner", daemon=True + global _LOOP, _THRD # pylint: disable=global-statement + if _LOOP is None or _THRD is None or not _THRD.is_alive(): + _LOOP = asyncio.new_event_loop() + _THRD = threading.Thread( + target=_LOOP.run_forever, name="Async Runner", daemon=True ) - _thrd.start() + _THRD.start() def async_to_sync(func): @@ -21,10 +21,10 @@ def async_to_sync(func): def run(*args, **kwargs): start_async_loop() # Ensure the loop and thread are started try: - future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _loop) + future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _LOOP) return future.result() - except Exception as e: - print(f"Exception occurred: {e}") + except Exception as exc: + print(f"Exception occurred: {exc}") raise return run diff --git a/zero/utils/type_util.py b/zero/utils/type_util.py index b64ffbd..fd69d4c 100644 --- a/zero/utils/type_util.py +++ b/zero/utils/type_util.py @@ -132,8 +132,8 @@ def verify_function_return_type(func: Callable): if origin_type is not None and origin_type in allowed_types: return - for t in msgspec_types: - if issubclass(return_type, t): + for typ in msgspec_types: + if issubclass(return_type, typ): return raise TypeError( diff --git a/zero/zeromq_patterns/__init__.py b/zero/zeromq_patterns/__init__.py index 125c8a6..9b05bc0 100644 --- a/zero/zeromq_patterns/__init__.py +++ b/zero/zeromq_patterns/__init__.py @@ -1,2 +1,2 @@ from .factory import get_async_client, get_broker, get_client, get_worker -from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker +from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker diff --git a/zero/zeromq_patterns/factory.py b/zero/zeromq_patterns/factory.py index f59d02c..e5a1abd 100644 --- a/zero/zeromq_patterns/factory.py +++ b/zero/zeromq_patterns/factory.py @@ -1,6 +1,6 @@ from zero.zeromq_patterns import queue_device -from .protocols import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker +from .interfaces import AsyncZeroMQClient, ZeroMQBroker, ZeroMQClient, ZeroMQWorker def get_client(pattern: str, default_timeout: int = 2000) -> ZeroMQClient: diff --git a/zero/zeromq_patterns/helpers.py b/zero/zeromq_patterns/helpers.py index 404c66b..d6bd441 100644 --- a/zero/zeromq_patterns/helpers.py +++ b/zero/zeromq_patterns/helpers.py @@ -7,7 +7,7 @@ def zpipe_async( - ctx: zmq.asyncio.Context, timeout: int = 1000 + ctx: zmq.asyncio.Context, ) -> Tuple[zmq.asyncio.Socket, zmq.asyncio.Socket]: # pragma: no cover """ Build inproc pipe for talking to threads @@ -20,8 +20,8 @@ def zpipe_async( sock_b = ctx.socket(zmq.PAIR) sock_a.linger = sock_b.linger = 0 sock_a.hwm = sock_b.hwm = 1 - sock_a.sndtimeo = sock_b.sndtimeo = timeout - sock_a.rcvtimeo = sock_b.rcvtimeo = timeout + sock_a.sndtimeo = sock_b.sndtimeo = 0 + sock_a.rcvtimeo = sock_b.rcvtimeo = 0 iface = f"inproc://{util.unique_id()}" sock_a.bind(iface) sock_b.connect(iface) diff --git a/zero/zeromq_patterns/interfaces.py b/zero/zeromq_patterns/interfaces.py new file mode 100644 index 0000000..8267008 --- /dev/null +++ b/zero/zeromq_patterns/interfaces.py @@ -0,0 +1,59 @@ +from typing import Callable, Optional, Protocol, runtime_checkable + + +@runtime_checkable +class ZeroMQClient(Protocol): # pragma: no cover + def __init__( + self, + address: str, + default_timeout: int = 2000, + ): + ... + + def connect(self, address: str) -> None: + ... + + def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: + ... + + def close(self) -> None: + ... + + +@runtime_checkable +class AsyncZeroMQClient(Protocol): # pragma: no cover + def __init__( + self, + address: str, + default_timeout: int = 2000, + ): + ... + + async def connect(self, address: str) -> None: + ... + + async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: + ... + + def close(self) -> None: + ... + + +@runtime_checkable +class ZeroMQBroker(Protocol): # pragma: no cover + def listen(self, address: str, channel: str) -> None: + ... + + def close(self) -> None: + ... + + +@runtime_checkable +class ZeroMQWorker(Protocol): # pragma: no cover + def listen( + self, address: str, msg_handler: Callable[[bytes, bytes], Optional[bytes]] + ) -> None: + ... + + def close(self) -> None: + ... diff --git a/zero/zeromq_patterns/protocols.py b/zero/zeromq_patterns/protocols.py deleted file mode 100644 index 3af0d1c..0000000 --- a/zero/zeromq_patterns/protocols.py +++ /dev/null @@ -1,86 +0,0 @@ -from typing import Any, Callable, Optional, Protocol, runtime_checkable - -import zmq -import zmq.asyncio - - -@runtime_checkable -class ZeroMQClient(Protocol): # pragma: no cover - @property - def context(self) -> zmq.Context: - ... - - def connect(self, address: str) -> None: - ... - - def close(self) -> None: - ... - - def send(self, message: bytes) -> None: - ... - - def send_multipart(self, message: list) -> None: - ... - - def poll(self, timeout: int) -> bool: - ... - - def recv(self) -> bytes: - ... - - def recv_multipart(self) -> list: - ... - - def request(self, message: bytes) -> Any: - ... - - -@runtime_checkable -class AsyncZeroMQClient(Protocol): # pragma: no cover - @property - def context(self) -> zmq.asyncio.Context: - ... - - def connect(self, address: str) -> None: - ... - - def close(self) -> None: - ... - - async def send(self, message: bytes) -> None: - ... - - async def send_multipart(self, message: list) -> None: - ... - - async def poll(self, timeout: int) -> bool: - ... - - async def recv(self) -> bytes: - ... - - async def recv_multipart(self) -> list: - ... - - async def request(self, message: bytes) -> Any: - ... - - -@runtime_checkable -class ZeroMQBroker(Protocol): # pragma: no cover - def listen(self, address: str, channel: str) -> None: - ... - - def close(self) -> None: - ... - - -@runtime_checkable -class ZeroMQWorker(Protocol): # pragma: no cover - def listen( - self, address: str, msg_handler: Callable[[bytes, bytes], Optional[bytes]] - ) -> None: - ... - - def close(self) -> None: - ... diff --git a/zero/zeromq_patterns/queue_device/broker.py b/zero/zeromq_patterns/queue_device/broker.py index 7d5645b..dd87af0 100644 --- a/zero/zeromq_patterns/queue_device/broker.py +++ b/zero/zeromq_patterns/queue_device/broker.py @@ -1,7 +1,7 @@ import logging import zmq -from zmq import proxy +from zmq.backend import proxy class ZeroMQBroker: diff --git a/zero/zeromq_patterns/queue_device/client.py b/zero/zeromq_patterns/queue_device/client.py index 997ce13..f159c39 100644 --- a/zero/zeromq_patterns/queue_device/client.py +++ b/zero/zeromq_patterns/queue_device/client.py @@ -1,12 +1,15 @@ import asyncio +import logging import sys -from typing import Optional +from asyncio import Event +from typing import Dict, Optional import zmq import zmq.asyncio as zmqasync import zmq.error as zmqerr from zero.error import ConnectionException, TimeoutException +from zero.utils import util class ZeroMQClient: @@ -23,38 +26,62 @@ def __init__(self, default_timeout): self.poller = zmq.Poller() self.poller.register(self.socket, zmq.POLLIN) - @property - def context(self): - return self._context - def connect(self, address: str) -> None: self._address = address self.socket.connect(address) + self._send(util.unique_id().encode() + b"connect" + b"") + self._recv() + logging.info("Connected to server at %s", self._address) + + def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: + _timeout = self._default_timeout if timeout is None else timeout + + def _poll_data(): + # poll is slow, need to find a better way + if not self._poll(_timeout): + raise TimeoutException( + f"Timeout while sending message at {self._address}" + ) + + rcv_data = self._recv() + + # first 32 bytes as response id + resp_id = rcv_data[:32].decode() + + # the rest is response data + resp_data = rcv_data[32:] + + return resp_id, resp_data + + req_id = util.unique_id() + self._send(req_id.encode() + message) + + resp_id, resp_data = None, None + # as the client is synchronous, we know that the response will be available any next poll + # we try to get the response until timeout because a previous call might be timed out + # and the response is still in the socket, + # so we poll until we get the response for this call + while resp_id != req_id: + resp_id, resp_data = _poll_data() + + return resp_data # type: ignore def close(self) -> None: self.socket.close() - def send(self, message: bytes) -> None: - try: - self.socket.send(message, zmq.DONTWAIT) - except zmqerr.Again as exc: - raise ConnectionException( - f"Connection error for send at {self._address}" - ) from exc - - def send_multipart(self, message: list) -> None: + def _send(self, message: bytes) -> None: try: - self.socket.send_multipart(message, copy=False) + self.socket.send(message, zmq.NOBLOCK) except zmqerr.Again as exc: raise ConnectionException( f"Connection error for send at {self._address}" ) from exc - def poll(self, timeout: int) -> bool: + def _poll(self, timeout: int) -> bool: socks = dict(self.poller.poll(timeout)) return self.socket in socks - def recv(self) -> bytes: + def _recv(self) -> bytes: try: return self.socket.recv() except zmqerr.Again as exc: @@ -62,25 +89,6 @@ def recv(self) -> bytes: f"Connection error for recv at {self._address}" ) from exc - def recv_multipart(self) -> list: - try: - return self.socket.recv_multipart() - except zmqerr.Again as exc: - raise ConnectionException( - f"Connection error for recv at {self._address}" - ) from exc - - def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: - try: - self.send(message) - if self.poll(timeout or self._default_timeout): - return self.recv() - raise TimeoutException(f"Timeout waiting for response from {self._address}") - except zmqerr.Again as exc: - raise ConnectionException( - f"Connection error for request at {self._address}" - ) from exc - class AsyncZeroMQClient: def __init__(self, default_timeout: int = 2000): @@ -100,60 +108,94 @@ def __init__(self, default_timeout: int = 2000): self.poller = zmqasync.Poller() self.poller.register(self.socket, zmq.POLLIN) - @property - def context(self): - return self._context + self._resp_map: Dict[str, bytes] = {} - def connect(self, address: str) -> None: + # self.peer1, self.peer2 = zpipe_async(self._context) + + async def connect(self, address: str) -> None: self._address = address self.socket.connect(address) + await self._send(util.unique_id().encode() + b"connect" + b"") + await self._recv() + logging.info("Connected to server at %s", self._address) + + async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: + _timeout = self._default_timeout if timeout is None else timeout + expire_at = util.current_time_us() + (_timeout * 1000) + + is_data = Event() + + async def _poll_data(): + # async has issue with poller, after 3-4 calls, it returns empty + # if not await self._poll(_timeout): + # raise TimeoutException(f"Timeout while sending message at {self._address}") + + resp = await self._recv() + + # first 32 bytes as response id + resp_id = resp[:32].decode() + + # the rest is response data + resp_data = resp[32:] + self._resp_map[resp_id] = resp_data + + # pipe is a good way to notify the main event loop that there is a response + # but pipe is actually slower than sleep, because it is a zmq socket + # yes it uses inproc, but still slower than asyncio.sleep + # try: + # await self.peer1.send(b"") + # except zmqerr.Again: + # # if the pipe is full, just pass + # pass + + is_data.set() + + req_id = util.unique_id() + await self._send(req_id.encode() + message) + + # poll can get response of a different call + # so we poll until we get the response of this call or timeout + await _poll_data() + + while req_id not in self._resp_map: + if util.current_time_us() > expire_at: + raise TimeoutException( + f"Timeout while waiting for response at {self._address}" + ) + + # await asyncio.sleep(1e-6) + await asyncio.wait_for(is_data.wait(), timeout=_timeout) + + # try: + # await self.peer2.recv() + # except zmqerr.Again: + # # if the pipe is empty, just pass + # pass + + resp_data = self._resp_map.pop(req_id) + + return resp_data def close(self) -> None: self.socket.close() + self._resp_map.clear() - async def send(self, message: bytes) -> None: - try: - await self.socket.send(message, zmq.DONTWAIT) - except zmqerr.Again as exc: - raise ConnectionException( - f"Connection error for send at {self._address}" - ) from exc - - async def send_multipart(self, message: list) -> None: + async def _send(self, message: bytes) -> None: try: - await self.socket.send_multipart(message, copy=False) + await self.socket.send(message, zmq.NOBLOCK) except zmqerr.Again as exc: raise ConnectionException( f"Connection error for send at {self._address}" ) from exc - async def poll(self, timeout: int) -> bool: + async def _poll(self, timeout: int) -> bool: socks = dict(await self.poller.poll(timeout)) return self.socket in socks - async def recv(self) -> bytes: + async def _recv(self) -> bytes: try: return await self.socket.recv() except zmqerr.Again as exc: raise ConnectionException( f"Connection error for recv at {self._address}" ) from exc - - async def recv_multipart(self) -> list: - try: - return await self.socket.recv_multipart() - except zmqerr.Again as exc: - raise ConnectionException( - f"Connection error for recv at {self._address}" - ) from exc - - async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: - try: - await self.send(message) - # TODO async has issue with poller, after 3-4 calls, it returns empty - # await self.poll(timeout or self._default_timeout) - return await self.recv() - except zmqerr.Again as exc: - raise ConnectionException( - f"Conection error for request at {self._address}" - ) from exc diff --git a/zero/zeromq_patterns/queue_device/worker.py b/zero/zeromq_patterns/queue_device/worker.py index 9d50aed..ebed4d5 100644 --- a/zero/zeromq_patterns/queue_device/worker.py +++ b/zero/zeromq_patterns/queue_device/worker.py @@ -41,18 +41,19 @@ def _recv_and_process(self, msg_handler: Callable[[bytes, bytes], Optional[bytes # first 32 bytes is request id req_id = data[:32] + data = data[32:] - # then 120 bytes is function name - func_name = data[32:152].strip() + # then 80 bytes is function name + func_name = data[:80].strip() # the rest is message - message = data[152:] + message = data[80:] response = msg_handler(func_name, message) - # TODO send is slow, need to find a way to make it faster + # send is slow, need to find a way to make it faster self.socket.send_multipart( - [ident, req_id + response if response else b""], copy=False + [ident, req_id + response if response else b""], zmq.NOBLOCK ) def close(self) -> None: