diff --git a/src/backend/server/request_handler.py b/src/backend/server/request_handler.py index 37dad7d2..f3345c33 100644 --- a/src/backend/server/request_handler.py +++ b/src/backend/server/request_handler.py @@ -1,4 +1,5 @@ import json +import time from typing import Dict import aiohttp @@ -6,6 +7,7 @@ from starlette.concurrency import iterate_in_threadpool from backend.server.constants import NODE_STATUS_AVAILABLE +from common.request_metrics import get_request_metrics from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -38,6 +40,7 @@ def get_stub(self, node_id): return self.stubs[node_id] async def _forward_request(self, request_data: Dict, request_id: str, received_ts: int): + start_time = time.time() logger.debug(f"Forwarding request {request_id}; stream={request_data.get('stream', False)}") if ( self.scheduler_manage is None @@ -100,11 +103,34 @@ async def _forward_request(self, request_data: Dict, request_id: str, received_t async def stream_generator(): response = stub.chat_completion(request_data) + first_token_time = None + last_chunk = None + last_token_time = None try: iterator = iterate_in_threadpool(response) async for chunk in iterator: + last_token_time = time.time() + if first_token_time is None: + first_token_time = last_token_time + if chunk is not None and not chunk.decode("utf-8").startswith( + "data: [DONE]" + ): + last_chunk = chunk yield chunk finally: + if last_chunk is not None: + tps, ttft, input_tokens, output_tokens = get_request_metrics( + last_chunk, start_time, first_token_time, last_token_time + ) + if ( + tps is not None + and ttft is not None + and input_tokens is not None + and output_tokens is not None + ): + logger.info( + f"Request ID: {request_id} | TPS: {tps:.2f} | TTFT: {ttft} ms | Output tokens: {output_tokens} | Input tokens: {input_tokens}" + ) logger.debug(f"client disconnected for {request_id}") response.cancel() diff --git a/src/common/request_metrics.py 
"""Helpers for deriving per-request metrics from a completion response chunk."""

import json
from typing import Optional, Tuple, Union

_NO_METRICS = (None, None, None, None)


def _parse_chunk(chunk) -> Optional[dict]:
    """Return *chunk* as a dict, or None when it is not a JSON object.

    Accepts an already-parsed dict, or bytes/str holding a JSON object that
    may be prefixed with an SSE ``data:`` field name.
    """
    if isinstance(chunk, dict):
        return chunk
    try:
        if isinstance(chunk, (bytes, bytearray)):
            chunk = chunk.decode("utf-8")
        if not isinstance(chunk, str):
            return None
        text = chunk.strip()
        # SSE allows both "data: {...}" and "data:{...}"; json.loads ignores
        # the optional space left behind after the field name is removed.
        if text.startswith("data:"):
            text = text[len("data:"):]
        parsed = json.loads(text)
    except ValueError:
        # Covers UnicodeDecodeError and json.JSONDecodeError (both ValueError).
        return None
    return parsed if isinstance(parsed, dict) else None


def get_request_metrics(
    chunk: Union[bytes, str, dict],
    start_time: Optional[float],
    first_token_time: Optional[float],
    last_token_time: Optional[float],
) -> Tuple[Optional[float], Optional[int], Optional[int], Optional[int]]:
    """Compute throughput and latency metrics from a chunk carrying ``usage``.

    Args:
        chunk: Response chunk as bytes, str (optionally ``data:``-prefixed),
            or an already-decoded dict. Its ``usage`` object is read for
            ``prompt_tokens`` and ``completion_tokens``.
        start_time: Timestamp (seconds) at which the request started.
        first_token_time: Timestamp (seconds) of the first received chunk.
        last_token_time: Timestamp (seconds) of the last received chunk.

    Returns:
        ``(tps, ttft, input_tokens, output_tokens)`` where ``tps`` is
        ``completion_tokens`` divided by the time between first and last
        chunk, and ``ttft`` is the time to first chunk in whole milliseconds.
        Any element that cannot be determined is ``None``; all four are
        ``None`` when the chunk cannot be parsed or has no ``usage`` object.
    """
    payload = _parse_chunk(chunk)
    if payload is None:
        return _NO_METRICS

    usage = payload.get("usage")
    if not isinstance(usage, dict):
        return _NO_METRICS

    input_tokens = usage.get("prompt_tokens")
    output_tokens = usage.get("completion_tokens")

    ttft = None
    if start_time is not None and first_token_time is not None:
        ttft = int((first_token_time - start_time) * 1000)

    tps = None
    if (
        isinstance(output_tokens, (int, float))
        and first_token_time is not None
        and last_token_time is not None
    ):
        elapsed = last_token_time - first_token_time
        # A single-chunk response has a zero-length window: throughput is
        # undefined there, but the remaining metrics are still valid.
        if elapsed > 0:
            tps = output_tokens / elapsed

    return tps, ttft, input_tokens, output_tokens