diff --git a/tests/concurrency/rps_async.py b/tests/concurrency/rps_async.py new file mode 100644 index 0000000..fcfad47 --- /dev/null +++ b/tests/concurrency/rps_async.py @@ -0,0 +1,56 @@ +import asyncio +import random +import time +from concurrent.futures import ProcessPoolExecutor + +from zero import AsyncZeroClient + +async_client = AsyncZeroClient("localhost", 5559) + + +async def task(semaphore, items): + async with semaphore: + try: + res = await async_client.call("sum_async", items) + # print(res) + except Exception as e: + print(e) + + +async def process_tasks(items_chunk): + conc = 8 + semaphore = asyncio.BoundedSemaphore(conc) + tasks = [task(semaphore, items) for items in items_chunk] + await asyncio.gather(*tasks) + await async_client.close() + + +def run_chunk(items_chunk): + asyncio.run(process_tasks(items_chunk)) + + +if __name__ == "__main__": + process_no = 8 + + print("Preparing data...") + + sum_items = [[random.randint(50, 500) for _ in range(10)] for _ in range(100000)] + + # Split sum_items into chunks for each process + chunk_size = len(sum_items) // process_no + items_chunks = [ + sum_items[i : i + chunk_size] for i in range(0, len(sum_items), chunk_size) + ] + + print("Running...") + + start = time.time() + + with ProcessPoolExecutor(max_workers=process_no) as executor: + executor.map(run_chunk, items_chunks) + + end = time.time() + time_taken_ms = 1e3 * (end - start) + + print(f"total time taken: {time_taken_ms} ms") + print(f"requests per second: {len(sum_items) / time_taken_ms * 1e3}") diff --git a/tests/concurrency/rps_sync.py b/tests/concurrency/rps_sync.py new file mode 100644 index 0000000..454545b --- /dev/null +++ b/tests/concurrency/rps_sync.py @@ -0,0 +1,37 @@ +import random +import time +from functools import partial +from multiprocessing.pool import Pool + +from zero import ZeroClient + +client = ZeroClient("localhost", 5559) + + +sum_func = partial(client.call, "sum_sync") + + +def get_and_sum(msg): + resp = sum_func(msg) + # print(resp) + + +if __name__ == "__main__": + process_no = 32 + pool = Pool(process_no) + + sum_items = [] + for _ in range(100000): + sum_items.append([random.randint(50, 500) for _ in range(10)]) + + start = time.time() + + res = pool.map_async(get_and_sum, sum_items) + pool.close() + pool.join() + + end = time.time() + time_taken_ms = 1e3 * (end - start) + + print(f"total time taken: {time_taken_ms} ms") + print(f"requests per second: {len(sum_items) / time_taken_ms * 1e3}") diff --git a/tests/concurrency/rps_sync_graph.py b/tests/concurrency/rps_sync_graph.py new file mode 100644 index 0000000..464d454 --- /dev/null +++ b/tests/concurrency/rps_sync_graph.py @@ -0,0 +1,44 @@ +import random +import time +from functools import partial +from multiprocessing.pool import Pool + +import matplotlib.pyplot as plt + +from zero import ZeroClient + +client = ZeroClient("localhost", 5559) + +sum_func = partial(client.call, "sum_sync") + + +def get_and_sum(msg): + return sum_func(msg) + + +if __name__ == "__main__": + + def run_task(process_no): + sum_items = [ + [random.randint(50, 500) for _ in range(10)] for _ in range(100000) + ] + + start = time.time() + with Pool(process_no) as pool: + pool.map_async(get_and_sum, sum_items) + pool.close() + pool.join() + end = time.time() + + time_taken_ms = 1e3 * (end - start) + requests_per_second = len(sum_items) / time_taken_ms * 1e3 + return requests_per_second + + process_numbers = range(2, 128, 2) # From 4 to 128, stepping by 4 + results = [run_task(pn) for pn in process_numbers] + + plt.plot(process_numbers, results) + plt.xlabel("Number of Processes") + plt.ylabel("Requests per Second") + plt.title("Performance by Number of Processes") + plt.show() diff --git a/tests/concurrency/server.py b/tests/concurrency/server.py index 28af43b..ba2baf4 100644 --- a/tests/concurrency/server.py +++ b/tests/concurrency/server.py @@ -22,5 +22,15 @@ async def sleep_async(msg: int) -> str: return f"slept for {msg} msecs" +@app.register_rpc +def sum_sync(msg: list) -> int: + return sum(msg) + + +@app.register_rpc +async def sum_async(msg: list) -> int: + return sum(msg) + + if __name__ == "__main__": - app.run(workers=8) + app.run(workers=4) diff --git a/tests/concurrency/async.py b/tests/concurrency/sleep_test_async.py similarity index 100% rename from tests/concurrency/async.py rename to tests/concurrency/sleep_test_async.py diff --git a/tests/concurrency/sync.py b/tests/concurrency/sleep_test_sync.py similarity index 83% rename from tests/concurrency/sync.py rename to tests/concurrency/sleep_test_sync.py index aa31403..057205d 100644 --- a/tests/concurrency/sync.py +++ b/tests/concurrency/sleep_test_sync.py @@ -1,3 +1,8 @@ +""" +This test ensures that the sync client can handle multiple requests concurrently. +And it doesn't mix up the responses. +""" + import random import time from functools import partial @@ -8,7 +13,7 @@ client = ZeroClient("localhost", 5559) -func = partial(client.call, "sleep") +func = partial(client.call, "sleep_async") def get_and_print(msg): @@ -19,7 +24,7 @@ def get_and_print(msg): if __name__ == "__main__": - process_no = 10 + process_no = 32 pool = Pool(process_no) sleep_times = [] diff --git a/zero/client_server/client.py b/zero/client_server/client.py index 476e9de..57db4cc 100644 --- a/zero/client_server/client.py +++ b/zero/client_server/client.py @@ -112,6 +112,7 @@ def call( _timeout = self._default_timeout if timeout is None else timeout def _poll_data(): + # TODO poll is slow, need to find a better way if not zmqc.poll(_timeout): raise TimeoutException( f"Timeout while sending message at {self._address}" @@ -249,7 +250,7 @@ async def call( async def _poll_data(): # TODO async has issue with poller, after 3-4 calls, it returns empty - # if not await self.zmq_client.poll(_timeout): + # if not await zmqc.poll(_timeout): # raise TimeoutException(f"Timeout while sending message at {self._address}") resp = await zmqc.recv() @@ -322,6 +323,7 @@ def __init__( def get(self) -> ZeroMQClient: thread_id = threading.get_ident() if thread_id not in self._pool: + logging.debug("No connection found in current thread, creating new one") self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout) self._pool[thread_id].connect(self._address) self._try_connect_ping(self._pool[thread_id]) @@ -360,6 +362,7 @@ def __init__( async def get(self) -> AsyncZeroMQClient: thread_id = threading.get_ident() if thread_id not in self._pool: + logging.debug("No connection found in current thread, creating new one") self._pool[thread_id] = get_async_client( config.ZEROMQ_PATTERN, self._timeout ) diff --git a/zero/client_server/worker.py b/zero/client_server/worker.py index 6e485df..e8d8bf2 100644 --- a/zero/client_server/worker.py +++ b/zero/client_server/worker.py @@ -8,6 +8,7 @@ from zero.codegen.codegen import CodeGen from zero.encoder.protocols import Encoder from zero.error import SERVER_PROCESSING_ERROR +from zero.utils.async_to_sync import async_to_sync from zero.zero_mq.factory import get_worker @@ -80,7 +81,8 @@ def handle_msg(self, rpc, msg): # TODO: is this a bottleneck if inspect.iscoroutinefunction(func): # this is blocking - ret = self._loop.run_until_complete(func(msg) if msg else func()) + # ret = self._loop.run_until_complete(func(msg) if msg else func()) + ret = async_to_sync(func)(msg) if msg else async_to_sync(func)() else: ret = func(msg) if msg else func() diff --git a/zero/utils/async_to_sync.py b/zero/utils/async_to_sync.py new file mode 100644 index 0000000..770986e --- /dev/null +++ b/zero/utils/async_to_sync.py @@ -0,0 +1,30 @@ +import asyncio +import threading +from functools import wraps + +_loop = None +_thrd = None + + +def start_async_loop(): + global _loop, _thrd + if _loop is None or _thrd is None or not _thrd.is_alive(): + _loop = asyncio.new_event_loop() + _thrd = threading.Thread( + target=_loop.run_forever, name="Async Runner", daemon=True + ) + _thrd.start() + + +def async_to_sync(func): + @wraps(func) + def run(*args, **kwargs): + start_async_loop() # Ensure the loop and thread are started + try: + future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _loop) + return future.result() + except Exception as e: + print(f"Exception occurred: {e}") + raise + + return run diff --git a/zero/zero_mq/queue_device/client.py b/zero/zero_mq/queue_device/client.py index ccecc71..1fe70ac 100644 --- a/zero/zero_mq/queue_device/client.py +++ b/zero/zero_mq/queue_device/client.py @@ -3,7 +3,8 @@ from typing import Optional import zmq -import zmq.asyncio +import zmq.asyncio as zmqasync +import zmq.error as zmqerr from zero.error import ConnectionException, TimeoutException @@ -36,7 +37,7 @@ def close(self) -> None: def send(self, message: bytes) -> None: try: self.socket.send(message, zmq.DONTWAIT) - except zmq.error.Again as exc: + except zmqerr.Again as exc: raise ConnectionException( f"Connection error for send at {self._address}" ) from exc @@ -48,7 +49,7 @@ def poll(self, timeout: int) -> bool: def recv(self) -> bytes: try: return self.socket.recv() - except zmq.error.Again as exc: + except zmqerr.Again as exc: raise ConnectionException( f"Connection error for recv at {self._address}" ) from exc @@ -59,7 +60,7 @@ def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: if self.poll(timeout or self._default_timeout): return self.recv() raise TimeoutException(f"Timeout waiting for response from {self._address}") - except zmq.error.Again as exc: + except zmqerr.Again as exc: raise ConnectionException( f"Connection error for request at {self._address}" ) from exc @@ -73,14 +74,14 @@ def __init__(self, default_timeout: int = 2000): self._address: str = None # type: ignore self._default_timeout = default_timeout - self._context = zmq.asyncio.Context.instance() + self._context = zmqasync.Context.instance() - self.socket: zmq.asyncio.Socket = self._context.socket(zmq.DEALER) + self.socket: zmqasync.Socket = self._context.socket(zmq.DEALER) self.socket.setsockopt(zmq.LINGER, 0) # dont buffer messages self.socket.setsockopt(zmq.RCVTIMEO, default_timeout) self.socket.setsockopt(zmq.SNDTIMEO, default_timeout) - self.poller = zmq.asyncio.Poller() + self.poller = zmqasync.Poller() self.poller.register(self.socket, zmq.POLLIN) @property @@ -97,7 +98,7 @@ def close(self) -> None: async def send(self, message: bytes) -> None: try: await self.socket.send(message, zmq.DONTWAIT) - except zmq.error.Again as exc: + except zmqerr.Again as exc: raise ConnectionException( f"Connection error for send at {self._address}" ) from exc @@ -109,7 +110,7 @@ async def poll(self, timeout: int) -> bool: async def recv(self) -> bytes: try: return await self.socket.recv() # type: ignore - except zmq.error.Again as exc: + except zmqerr.Again as exc: raise ConnectionException( f"Connection error for recv at {self._address}" ) from exc @@ -120,7 +121,7 @@ async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes: # TODO async has issue with poller, after 3-4 calls, it returns empty # await self.poll(timeout or self._default_timeout) return await self.recv() - except zmq.error.Again as exc: + except zmqerr.Again as exc: raise ConnectionException( f"Conection error for request at {self._address}" ) from exc diff --git a/zero/zero_mq/queue_device/worker.py b/zero/zero_mq/queue_device/worker.py index bc92f34..5377adc 100644 --- a/zero/zero_mq/queue_device/worker.py +++ b/zero/zero_mq/queue_device/worker.py @@ -36,6 +36,8 @@ def _recv_and_process(self, msg_handler: Callable[[bytes], Optional[bytes]]): ident, message = frames response = msg_handler(message) + + # TODO send is slow, need to find a way to make it faster self.socket.send_multipart([ident, response], zmq.NOBLOCK) def close(self) -> None: