Skip to content

Commit

Permalink
feat(generate): cleanup async generator on termination
Browse files Browse the repository at this point in the history
Ref: #180

Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
  • Loading branch information
Tomas2D committed Sep 28, 2023
1 parent 9ccf894 commit eb82a58
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 12 deletions.
51 changes: 39 additions & 12 deletions src/genai/services/async_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from signal import SIGINT, SIGTERM, signal

from genai.exceptions import GenAiException
from genai.options import Options
from genai.schemas.responses import GenerateResponse, TokenizeResponse
from genai.services.connection_manager import ConnectionManager
from genai.utils.errors import to_genai_error

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -47,12 +49,16 @@ def __init__(
self.ordered = ordered
self.options = options
self.throw_on_error = throw_on_error
self.tokenize_client_close_fn_ = None
self.generate_client_close_fn_ = None
self._is_terminating = False

def __enter__(self):
    """Set up per-run state when the generator is entered as a context manager.

    Resets the result accumulator and the termination flag, configures the
    function-specific parameters (batch size, service callable, client close
    hooks), and creates the queue and event loop used by the background
    request launcher.

    Returns:
        The generator itself, for use in ``with ... as`` statements.
    """
    self.accumulator = []
    # Chooses batch size / service fn / client-close callbacks based on self.fn.
    self._initialize_fn_specific_params()
    # Thread-safe channel between the async worker tasks and the consumer.
    self.queue_ = Queue()
    self.loop_ = asyncio.new_event_loop()
    # Cleared on every entry so a SIGINT/SIGTERM from a previous run does not
    # leak into this one (the signal handlers set this flag to True).
    self._is_terminating = False
    return self

def _shutdown(self):
Expand All @@ -71,7 +77,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self._shutdown()
except Exception as e:
logger.error(str(e))
raise GenAiException(e)
raise to_genai_error(e)

def _initialize_fn_specific_params(self):
if self.fn == "generate":
Expand All @@ -81,7 +87,7 @@ def _initialize_fn_specific_params(self):
self.service_fn_ = self.service.async_generate
self.max_active_tasks_ = ConnectionManager.MAX_CONCURRENT_GENERATE
ConnectionManager.make_generate_client()
self.client_close_fn_ = ConnectionManager.delete_generate_client
self.generate_client_close_fn_ = ConnectionManager.delete_generate_client
elif self.fn == "tokenize":
self.batch_size_ = 5
a, b = divmod(len(self.prompts), self.batch_size_)
Expand All @@ -90,7 +96,7 @@ def _initialize_fn_specific_params(self):
self.service_fn_ = self.service.async_tokenize
self.max_active_tasks_ = ConnectionManager.MAX_REQ_PER_SECOND_TOKENIZE
ConnectionManager.make_tokenize_client()
self.client_close_fn_ = ConnectionManager.delete_tokenize_client
self.tokenize_client_close_fn_ = ConnectionManager.delete_tokenize_client

def _generate_batch(self):
for i in range(0, len(self.prompts), self.batch_size_):
Expand Down Expand Up @@ -119,6 +125,17 @@ async def _get_response_json(self, model, inputs, params, options):

async def _task(self, inputs, batch_num):
async with self.semaphore_:
if self._is_terminating:
self.queue_.put_nowait(
(
batch_num,
len(inputs),
None,
GenAiException("Generation has been aborted by the user."),
)
)
return

response = None
try:
response = await self._get_response_json(self.model_id, inputs, self.params, self.options)
Expand All @@ -133,14 +150,7 @@ async def _task(self, inputs, batch_num):
str(e), response, inputs
)
)
self.queue_.put_nowait(
(
batch_num,
len(inputs),
None,
GenAiException(e),
)
)
self.queue_.put_nowait((batch_num, len(inputs), None, to_genai_error(e)))
return
try:
self.queue_.put_nowait((batch_num, len(inputs), response, None))
Expand All @@ -166,7 +176,14 @@ async def _schedule_requests(self):
def _request_launcher(self):
asyncio.set_event_loop(self.loop_)
self.loop_.run_until_complete(self._schedule_requests())
self.loop_.run_until_complete(self.client_close_fn_())
self.loop_.run_until_complete(self._cleanup())

async def _cleanup(self):
    """Release whichever API clients this run actually opened.

    Only the close callbacks registered during initialization are invoked;
    the generate client is always closed before the tokenize client,
    preserving the original ordering.
    """
    for close_client in (self.generate_client_close_fn_, self.tokenize_client_close_fn_):
        if close_client is not None:
            await close_client()

def generate_response(
self,
Expand All @@ -179,8 +196,18 @@ def generate_response(
"""
if len(self.prompts) == 0:
return

with ThreadPoolExecutor(max_workers=1) as executor:

def init_termination(*args):
self._is_terminating = True
executor.shutdown(cancel_futures=False, wait=False)

signal(SIGTERM, init_termination)
signal(SIGINT, init_termination)

executor.submit(self._request_launcher)

counter = 0
minheap, batch_tracker = [], 0
while counter < self.num_batches_:
Expand Down
10 changes: 10 additions & 0 deletions src/genai/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from genai.exceptions import GenAiException

__all__ = ["to_genai_error"]


def to_genai_error(e: Exception) -> GenAiException:
    """Coerce *e* into a ``GenAiException``.

    An exception that is already a ``GenAiException`` is returned unchanged,
    so re-wrapping never obscures the original error; anything else is
    wrapped in a new ``GenAiException``.
    """
    return e if isinstance(e, GenAiException) else GenAiException(e)

0 comments on commit eb82a58

Please sign in to comment.