diff --git a/pyproject.toml b/pyproject.toml index fdc6d91f..a53887e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "pydantic", "protobuf==6.31.1", "dijkstar==2.6.0", - "lattica==1.0.4", + "lattica==1.0.6", ] [project.scripts] diff --git a/src/backend/server/request_handler.py b/src/backend/server/request_handler.py index 36354bcb..b32b188a 100644 --- a/src/backend/server/request_handler.py +++ b/src/backend/server/request_handler.py @@ -3,6 +3,7 @@ import aiohttp from fastapi.responses import JSONResponse, StreamingResponse +from starlette.concurrency import iterate_in_threadpool from backend.server.constants import NODE_STATUS_AVAILABLE from parallax_utils.logging_config import get_logger @@ -92,29 +93,41 @@ async def _forward_request(self, request_data: Dict, request_id: str, received_t request_data["routing_table"] = routing_table stub = self.get_stub(routing_table[0]) is_stream = request_data.get("stream", False) - - if is_stream: - - def stream_generator(): - for chunk in stub.chat_completion(request_data): - yield chunk - - resp = StreamingResponse( - stream_generator(), - media_type="text/event-stream", - headers={ - "X-Content-Type-Options": "nosniff", - "Cache-Control": "no-cache", - }, + try: + if is_stream: + + async def stream_generator(): + response = stub.chat_completion(request_data) + try: + iterator = iterate_in_threadpool(response) + async for chunk in iterator: + yield chunk + finally: + logger.debug(f"client disconnected for {request_id}") + response.cancel() + + resp = StreamingResponse( + stream_generator(), + media_type="text/event-stream", + headers={ + "X-Content-Type-Options": "nosniff", + "Cache-Control": "no-cache", + }, + ) + logger.debug(f"Streaming response initiated for {request_id}") + return resp + else: + response = stub.chat_completion(request_data) + response = next(response).decode() + logger.debug(f"Non-stream response completed for {request_id}") + # response is a JSON string; parse to Python object before returning + return JSONResponse(content=json.loads(response)) + except Exception as e: + logger.exception(f"Error in _forward_request: {e}") + return JSONResponse( + content={"error": "Internal server error"}, + status_code=500, ) - logger.debug(f"Streaming response initiated for {request_id}") - return resp - else: - response = stub.chat_completion(request_data) - response = next(response).decode() - logger.debug(f"Non-stream response completed for {request_id}") - # response is a JSON string; parse to Python object before returning - return JSONResponse(content=json.loads(response)) async def v1_chat_completions(self, request_data: Dict, request_id: str, received_ts: int): return await self._forward_request(request_data, request_id, received_ts)