From ada4305492763a9a53da7e8cc7338fbba850dfa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radek=20Je=C5=BEek?= Date: Thu, 1 Feb 2024 15:40:24 +0100 Subject: [PATCH] fix: make SharedResource threadsafe MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Radek Ježek --- src/genai/_utils/shared_instance.py | 22 +++++++++++++--------- tests/unit/utils/test_async_executor.py | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/genai/_utils/shared_instance.py b/src/genai/_utils/shared_instance.py index 268e0c24..084cdba3 100644 --- a/src/genai/_utils/shared_instance.py +++ b/src/genai/_utils/shared_instance.py @@ -1,3 +1,4 @@ +import threading from abc import abstractmethod from contextlib import AbstractAsyncContextManager, AbstractContextManager from typing import Generic, Optional, TypeVar @@ -16,6 +17,7 @@ class SharedResource(Generic[T], AbstractContextManager): def __init__(self): self._ref_count = 0 self._resource: Optional[T] = None + self._lock = threading.Lock() @abstractmethod def _enter(self) -> T: @@ -35,18 +37,20 @@ def _exit(self) -> None: raise NotImplementedError def __enter__(self) -> T: - self._ref_count += 1 - if self._ref_count == 1: - self._resource = self._enter() + with self._lock: + self._ref_count += 1 + if self._ref_count == 1: + self._resource = self._enter() - assert self._resource - return self._resource + assert self._resource + return self._resource def __exit__(self, exc_type, exc_val, exc_tb): - self._ref_count -= 1 - if self._ref_count == 0: - self._exit() - self._resource = None + with self._lock: + self._ref_count -= 1 + if self._ref_count == 0: + self._exit() + self._resource = None class AsyncSharedResource(Generic[T], AbstractAsyncContextManager): diff --git a/tests/unit/utils/test_async_executor.py b/tests/unit/utils/test_async_executor.py index f9c350a4..50b8bf33 100644 --- a/tests/unit/utils/test_async_executor.py +++ b/tests/unit/utils/test_async_executor.py @@ -1,3 +1,4 @@ +import asyncio import logging from asyncio import sleep from unittest.mock import Mock @@ -109,3 +110,25 @@ async def handler(input: str, *args) -> str: def test_execute_empty_inputs(self): for _ in execute_async(inputs=[], handler=Mock(), http_client=Mock(), throw_on_error=True): ... + + @pytest.mark.asyncio + async def test_async_executor_can_be_used_in_async_context(self, http_client): + """Async executor can be used in asyncio event loop using asyncio.to_thread""" + + def _execute(input: str): + return list( + execute_async( + inputs=[input], + handler=self.get_handler([input]), + http_client=lambda: AsyncHttpxClient(), + throw_on_error=True, + ordered=True, + limiters=[LoopBoundLimiter(lambda: LocalLimiter(limit=10))], + ) + )[0] + + inputs = ["Hello", "World", "here", "are", "some", "inputs"] * 50 + tasks = [asyncio.to_thread(_execute, input) for input in inputs] + results = await asyncio.gather(*tasks) + + assert results == inputs