Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/backend/server/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
import time
from typing import Dict

import aiohttp
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.concurrency import iterate_in_threadpool

from backend.server.constants import NODE_STATUS_AVAILABLE
from common.request_metrics import get_request_metrics
from parallax_utils.logging_config import get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -38,6 +40,7 @@ def get_stub(self, node_id):
return self.stubs[node_id]

async def _forward_request(self, request_data: Dict, request_id: str, received_ts: int):
start_time = time.time()
logger.debug(f"Forwarding request {request_id}; stream={request_data.get('stream', False)}")
if (
self.scheduler_manage is None
Expand Down Expand Up @@ -100,11 +103,34 @@ async def _forward_request(self, request_data: Dict, request_id: str, received_t

async def stream_generator():
    """Proxy the upstream streaming chat-completion response to the client.

    Relays each SSE chunk produced by the node stub and, once the stream
    ends (or the client disconnects), computes and logs per-request
    metrics (TPS / TTFT / token counts) from the last data chunk.
    """
    response = stub.chat_completion(request_data)
    # Timing markers used by get_request_metrics().
    first_token_time = None
    # Last non-[DONE] chunk; presumably carries the "usage" payload — the
    # metrics helper parses it as JSON after a "data: " prefix.
    last_chunk = None
    last_token_time = None
    try:
        # The stub's iterator is synchronous; run it in a worker thread so
        # the event loop is not blocked between chunks.
        iterator = iterate_in_threadpool(response)
        async for chunk in iterator:
            last_token_time = time.time()
            if first_token_time is None:
                first_token_time = last_token_time
            # Keep the final data chunk for metrics, skipping the SSE
            # "data: [DONE]" terminator frame.
            if chunk is not None and not chunk.decode("utf-8").startswith(
                "data: [DONE]"
            ):
                last_chunk = chunk
            # NOTE(review): assumes the yield sits OUTSIDE the prefix check
            # so the [DONE] frame is still forwarded to the client — confirm
            # against the original file's indentation.
            yield chunk
    finally:
        # Runs on normal completion AND on client disconnect.
        if last_chunk is not None:
            tps, ttft, input_tokens, output_tokens = get_request_metrics(
                last_chunk, start_time, first_token_time, last_token_time
            )
            # Only log when every metric was successfully computed; the
            # helper returns four Nones on any parse/compute failure.
            if (
                tps is not None
                and ttft is not None
                and input_tokens is not None
                and output_tokens is not None
            ):
                logger.info(
                    f"Request ID: {request_id} | TPS: {tps:.2f} | TTFT: {ttft} ms | Output tokens: {output_tokens} | Input tokens: {input_tokens}"
                )
        # NOTE(review): this debug line fires on every stream end, not only
        # on actual disconnects — the message may be misleading.
        logger.debug(f"client disconnected for {request_id}")
        response.cancel()

Expand Down
19 changes: 19 additions & 0 deletions src/common/request_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import json


def get_request_metrics(chunk, start_time, first_token_time, last_token_time):
    """Parse the final SSE chunk of a streamed chat completion and compute
    request-level metrics.

    Args:
        chunk: Last non-``[DONE]`` SSE chunk as ``bytes`` or ``str``,
            expected to look like ``data: {...}`` with a ``usage`` object,
            or an already-parsed dict.
        start_time: Wall-clock time (seconds) the request was forwarded.
        first_token_time: Wall-clock time of the first streamed chunk.
        last_token_time: Wall-clock time of the last streamed chunk.

    Returns:
        Tuple ``(tps, ttft_ms, input_tokens, output_tokens)`` where ``tps``
        is tokens/second over the decode window and ``ttft_ms`` is the
        time-to-first-token in milliseconds. All four values are ``None``
        when the chunk cannot be parsed or the metrics cannot be computed
        (missing usage info, zero-length decode window, ...).
    """
    try:
        if isinstance(chunk, bytes):
            chunk = chunk.decode("utf-8")
        if isinstance(chunk, str):
            # SSE frames are prefixed with "data: "; strip it before parsing.
            chunk = json.loads(chunk.removeprefix("data: ").lstrip())
        usage = chunk.get("usage")
        if not usage:
            # No usage payload in the final chunk — nothing to report.
            return None, None, None, None
        input_tokens = usage.get("prompt_tokens")
        output_tokens = usage.get("completion_tokens")
        decode_window = last_token_time - first_token_time
        if output_tokens is None or decode_window <= 0:
            # A single-chunk stream has no decode window; bail out
            # explicitly instead of raising ZeroDivisionError.
            return None, None, None, None
        tps = output_tokens / decode_window
        ttft = int((first_token_time - start_time) * 1000)
        return tps, ttft, input_tokens, output_tokens
    except Exception:
        # Best-effort metrics: never let metric parsing break the request path.
        return None, None, None, None