Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/concurrency/rps_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import asyncio
import random
import time
from concurrent.futures import ProcessPoolExecutor

from zero import AsyncZeroClient

async_client = AsyncZeroClient("localhost", 5559)


async def task(semaphore, items):
async with semaphore:
try:
res = await async_client.call("sum_async", items)
# print(res)
except Exception as e:
print(e)


async def process_tasks(items_chunk):
conc = 8
semaphore = asyncio.BoundedSemaphore(conc)
tasks = [task(semaphore, items) for items in items_chunk]
await asyncio.gather(*tasks)
await async_client.close()


def run_chunk(items_chunk):
asyncio.run(process_tasks(items_chunk))


if __name__ == "__main__":
process_no = 8

print("Preparing data...")

sum_items = [[random.randint(50, 500) for _ in range(10)] for _ in range(100000)]

# Split sum_items into chunks for each process
chunk_size = len(sum_items) // process_no
items_chunks = [
sum_items[i : i + chunk_size] for i in range(0, len(sum_items), chunk_size)
]

print("Running...")

start = time.time()

with ProcessPoolExecutor(max_workers=process_no) as executor:
executor.map(run_chunk, items_chunks)

end = time.time()
time_taken_ms = 1e3 * (end - start)

print(f"total time taken: {time_taken_ms} ms")
print(f"requests per second: {len(sum_items) / time_taken_ms * 1e3}")
37 changes: 37 additions & 0 deletions tests/concurrency/rps_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import random
import time
from functools import partial
from multiprocessing.pool import Pool

from zero import ZeroClient

client = ZeroClient("localhost", 5559)


sum_func = partial(client.call, "sum_sync")


def get_and_sum(msg):
resp = sum_func(msg)
# print(resp)


if __name__ == "__main__":
process_no = 32
pool = Pool(process_no)

sum_items = []
for _ in range(100000):
sum_items.append([random.randint(50, 500) for _ in range(10)])

start = time.time()

res = pool.map_async(get_and_sum, sum_items)
pool.close()
pool.join()

end = time.time()
time_taken_ms = 1e3 * (end - start)

print(f"total time taken: {time_taken_ms} ms")
print(f"requests per second: {len(sum_items) / time_taken_ms * 1e3}")
44 changes: 44 additions & 0 deletions tests/concurrency/rps_sync_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import random
import time
from functools import partial
from multiprocessing.pool import Pool

import matplotlib.pyplot as plt

from zero import ZeroClient

client = ZeroClient("localhost", 5559)

sum_func = partial(client.call, "sum_sync")


def get_and_sum(msg):
return sum_func(msg)


if __name__ == "__main__":

def run_task(process_no):
sum_items = [
[random.randint(50, 500) for _ in range(10)] for _ in range(100000)
]

start = time.time()
with Pool(process_no) as pool:
pool.map_async(get_and_sum, sum_items)
pool.close()
pool.join()
end = time.time()

time_taken_ms = 1e3 * (end - start)
requests_per_second = len(sum_items) / time_taken_ms * 1e3
return requests_per_second

process_numbers = range(2, 128, 2) # From 4 to 128, stepping by 4
results = [run_task(pn) for pn in process_numbers]

plt.plot(process_numbers, results)
plt.xlabel("Number of Processes")
plt.ylabel("Requests per Second")
plt.title("Performance by Number of Processes")
plt.show()
12 changes: 11 additions & 1 deletion tests/concurrency/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,15 @@ async def sleep_async(msg: int) -> str:
return f"slept for {msg} msecs"


@app.register_rpc
def sum_sync(msg: list) -> int:
return sum(msg)


@app.register_rpc
async def sum_async(msg: list) -> int:
return sum(msg)


if __name__ == "__main__":
app.run(workers=8)
app.run(workers=4)
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
"""
This test ensures that the sync client can handle multiple requests concurrently.
And it doesn't mix up the responses.
"""

import random
import time
from functools import partial
Expand All @@ -8,7 +13,7 @@
client = ZeroClient("localhost", 5559)


func = partial(client.call, "sleep")
func = partial(client.call, "sleep_async")


def get_and_print(msg):
Expand All @@ -19,7 +24,7 @@ def get_and_print(msg):


if __name__ == "__main__":
process_no = 10
process_no = 32
pool = Pool(process_no)

sleep_times = []
Expand Down
5 changes: 4 additions & 1 deletion zero/client_server/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def call(
_timeout = self._default_timeout if timeout is None else timeout

def _poll_data():
# TODO poll is slow, need to find a better way
if not zmqc.poll(_timeout):
raise TimeoutException(
f"Timeout while sending message at {self._address}"
Expand Down Expand Up @@ -249,7 +250,7 @@ async def call(

async def _poll_data():
# TODO async has issue with poller, after 3-4 calls, it returns empty
# if not await self.zmq_client.poll(_timeout):
# if not await zmqc.poll(_timeout):
# raise TimeoutException(f"Timeout while sending message at {self._address}")

resp = await zmqc.recv()
Expand Down Expand Up @@ -322,6 +323,7 @@ def __init__(
def get(self) -> ZeroMQClient:
thread_id = threading.get_ident()
if thread_id not in self._pool:
logging.debug("No connection found in current thread, creating new one")
self._pool[thread_id] = get_client(config.ZEROMQ_PATTERN, self._timeout)
self._pool[thread_id].connect(self._address)
self._try_connect_ping(self._pool[thread_id])
Expand Down Expand Up @@ -360,6 +362,7 @@ def __init__(
async def get(self) -> AsyncZeroMQClient:
thread_id = threading.get_ident()
if thread_id not in self._pool:
logging.debug("No connection found in current thread, creating new one")
self._pool[thread_id] = get_async_client(
config.ZEROMQ_PATTERN, self._timeout
)
Expand Down
4 changes: 3 additions & 1 deletion zero/client_server/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from zero.codegen.codegen import CodeGen
from zero.encoder.protocols import Encoder
from zero.error import SERVER_PROCESSING_ERROR
from zero.utils.async_to_sync import async_to_sync
from zero.zero_mq.factory import get_worker


Expand Down Expand Up @@ -80,7 +81,8 @@ def handle_msg(self, rpc, msg):
# TODO: is this a bottleneck
if inspect.iscoroutinefunction(func):
# this is blocking
ret = self._loop.run_until_complete(func(msg) if msg else func())
# ret = self._loop.run_until_complete(func(msg) if msg else func())
ret = async_to_sync(func)(msg) if msg else async_to_sync(func)()
else:
ret = func(msg) if msg else func()

Expand Down
30 changes: 30 additions & 0 deletions zero/utils/async_to_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import asyncio
import threading
from functools import wraps

_loop = None
_thrd = None


def start_async_loop():
global _loop, _thrd
if _loop is None or _thrd is None or not _thrd.is_alive():
_loop = asyncio.new_event_loop()
_thrd = threading.Thread(
target=_loop.run_forever, name="Async Runner", daemon=True
)
_thrd.start()


def async_to_sync(func):
@wraps(func)
def run(*args, **kwargs):
start_async_loop() # Ensure the loop and thread are started
try:
future = asyncio.run_coroutine_threadsafe(func(*args, **kwargs), _loop)
return future.result()
except Exception as e:
print(f"Exception occurred: {e}")
raise

return run
21 changes: 11 additions & 10 deletions zero/zero_mq/queue_device/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from typing import Optional

import zmq
import zmq.asyncio
import zmq.asyncio as zmqasync
import zmq.error as zmqerr

from zero.error import ConnectionException, TimeoutException

Expand Down Expand Up @@ -36,7 +37,7 @@ def close(self) -> None:
def send(self, message: bytes) -> None:
try:
self.socket.send(message, zmq.DONTWAIT)
except zmq.error.Again as exc:
except zmqerr.Again as exc:
raise ConnectionException(
f"Connection error for send at {self._address}"
) from exc
Expand All @@ -48,7 +49,7 @@ def poll(self, timeout: int) -> bool:
def recv(self) -> bytes:
try:
return self.socket.recv()
except zmq.error.Again as exc:
except zmqerr.Again as exc:
raise ConnectionException(
f"Connection error for recv at {self._address}"
) from exc
Expand All @@ -59,7 +60,7 @@ def request(self, message: bytes, timeout: Optional[int] = None) -> bytes:
if self.poll(timeout or self._default_timeout):
return self.recv()
raise TimeoutException(f"Timeout waiting for response from {self._address}")
except zmq.error.Again as exc:
except zmqerr.Again as exc:
raise ConnectionException(
f"Connection error for request at {self._address}"
) from exc
Expand All @@ -73,14 +74,14 @@ def __init__(self, default_timeout: int = 2000):

self._address: str = None # type: ignore
self._default_timeout = default_timeout
self._context = zmq.asyncio.Context.instance()
self._context = zmqasync.Context.instance()

self.socket: zmq.asyncio.Socket = self._context.socket(zmq.DEALER)
self.socket: zmqasync.Socket = self._context.socket(zmq.DEALER)
self.socket.setsockopt(zmq.LINGER, 0) # dont buffer messages
self.socket.setsockopt(zmq.RCVTIMEO, default_timeout)
self.socket.setsockopt(zmq.SNDTIMEO, default_timeout)

self.poller = zmq.asyncio.Poller()
self.poller = zmqasync.Poller()
self.poller.register(self.socket, zmq.POLLIN)

@property
Expand All @@ -97,7 +98,7 @@ def close(self) -> None:
async def send(self, message: bytes) -> None:
try:
await self.socket.send(message, zmq.DONTWAIT)
except zmq.error.Again as exc:
except zmqerr.Again as exc:
raise ConnectionException(
f"Connection error for send at {self._address}"
) from exc
Expand All @@ -109,7 +110,7 @@ async def poll(self, timeout: int) -> bool:
async def recv(self) -> bytes:
try:
return await self.socket.recv() # type: ignore
except zmq.error.Again as exc:
except zmqerr.Again as exc:
raise ConnectionException(
f"Connection error for recv at {self._address}"
) from exc
Expand All @@ -120,7 +121,7 @@ async def request(self, message: bytes, timeout: Optional[int] = None) -> bytes:
# TODO async has issue with poller, after 3-4 calls, it returns empty
# await self.poll(timeout or self._default_timeout)
return await self.recv()
except zmq.error.Again as exc:
except zmqerr.Again as exc:
raise ConnectionException(
f"Conection error for request at {self._address}"
) from exc
2 changes: 2 additions & 0 deletions zero/zero_mq/queue_device/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def _recv_and_process(self, msg_handler: Callable[[bytes], Optional[bytes]]):

ident, message = frames
response = msg_handler(message)

# TODO send is slow, need to find a way to make it faster
self.socket.send_multipart([ident, response], zmq.NOBLOCK)

def close(self) -> None:
Expand Down