Skip to content

Commit

Permalink
optimize batch collate
Browse files Browse the repository at this point in the history
  • Loading branch information
aniketmaurya committed Jun 19, 2024
1 parent 5995b0a commit 7bffe8d
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import time
import os
import shutil
from typing import Sequence, Optional, Union, List
from typing import Sequence, Optional, Union, List, Dict
import uuid

from fastapi import FastAPI, Depends, HTTPException, BackgroundTasks, Request, Response
Expand Down Expand Up @@ -68,16 +68,28 @@ def get_batch_from_uid(uids, lit_api, request_buffer):
return batches


def collate_requests(
    lit_api: "LitAPI", request_queue: Queue, request_buffer: Dict, max_batch_size: int, batch_timeout: float
) -> Optional[List]:
    """Collect request uids from the queue and turn them into a batch.

    Uids are pulled from ``request_queue`` until either ``max_batch_size``
    uids have been gathered or ``batch_timeout`` seconds have elapsed since
    the call started, whichever comes first.

    Args:
        lit_api: Forwarded unchanged to ``get_batch_from_uid``.
        request_queue: Queue yielding request uids.
        request_buffer: Forwarded unchanged to ``get_batch_from_uid``.
        max_batch_size: Upper bound on the number of uids collected.
        batch_timeout: Maximum time, in seconds, to spend collecting.

    Returns:
        Whatever ``get_batch_from_uid`` returns for the collected uids, or
        ``None`` when no uid arrived before the deadline.
    """
    uids = []
    # A monotonic clock keeps the deadline immune to wall-clock adjustments
    # (NTP steps, manual changes) that could stretch or cut short the window.
    deadline = time.monotonic() + batch_timeout

    while len(uids) < max_batch_size:
        remaining_time = deadline - time.monotonic()
        if remaining_time <= 0:
            break

        try:
            # Poll in slices of at most 1 ms, never waiting past the deadline.
            # remaining_time is strictly positive here, so Queue.get cannot
            # raise ValueError for a negative timeout.
            uids.append(request_queue.get(timeout=min(remaining_time, 0.001)))
        except Empty:
            continue

    if not uids:
        return None

    return get_batch_from_uid(uids, lit_api, request_buffer)


def run_batched_loop(lit_api, lit_spec, request_queue: Queue, request_buffer, max_batch_size, batch_timeout):
Expand Down Expand Up @@ -470,12 +482,13 @@ def get_from_pipe(self, read):

async def data_reader(self, read):
    """Wait until ``read`` has data available, then receive and return it.

    A reader callback is registered on the connection's file descriptor so
    the coroutine can suspend instead of blocking the event loop while the
    connection is empty.

    Args:
        read: Connection-like object exposing ``fileno()``, ``poll()`` and
            ``recv()``.

    Returns:
        The object returned by ``read.recv()``.
    """
    data_available = asyncio.Event()
    loop = asyncio.get_running_loop()
    fd = read.fileno()
    loop.add_reader(fd, data_available.set)
    try:
        # Only suspend when nothing is buffered yet; otherwise fall straight
        # through to recv().
        if not read.poll():
            await data_available.wait()
    finally:
        # Always unregister, even if the wait is cancelled or raises, so the
        # callback does not stay attached to the file descriptor.
        loop.remove_reader(fd)
    return read.recv()

async def win_data_streamer(self, read, write):
Expand Down

0 comments on commit 7bffe8d

Please sign in to comment.