Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies = [
"msgpack>=1.0.7",
"safetensors>=0.5.1",
"huggingface-hub",
"jinja2>=3.1.0",
"numpy>=1.26",
"pyzmq>=25.0",
"psutil>=5.9.5",
Expand Down
74 changes: 53 additions & 21 deletions src/parallax/server/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

import argparse
import time
from http import HTTPStatus
from typing import Any, Dict, List, Optional

import mlx.core as mx
import torch
import zmq
from jinja2 import TemplateError
from mlx_lm.server import convert_chat, process_message_content

from parallax.p2p.message_util import (
Expand Down Expand Up @@ -328,29 +330,31 @@ def _join_requests(self, left_reqs: List[Request], right_reqs: List[Request]):

def recv_requests_from_http(self) -> Optional[List[Request]]:
    """Drain all pending messages from the HTTP frontend without blocking.

    Only tensor-parallel rank 0 reads the IPC socket. Each message is either
    an abort notification (a dict with ``"type" == "abort"``), which is
    forwarded to the scheduler, or a raw request that is converted via
    ``_handle_raw_request`` and collected.

    Returns:
        ``None`` when ``self.tp_rank`` is not 0; otherwise the (possibly
        empty) list of requests built from the messages read in this call.
    """
    if self.tp_rank != 0:
        return None

    recv_reqs = []
    while True:
        # Reset on every iteration: if recv_pyobj itself fails with a
        # non-ZMQ error, the handler below must not see an unbound name or
        # the previous iteration's request (and report against its rid).
        raw_request = None
        try:
            raw_request = self.recv_from_ipc_socket.recv_pyobj(zmq.NOBLOCK)

            # Check if this is an abort request
            if isinstance(raw_request, dict) and raw_request.get("type") == "abort":
                logger.debug(
                    f"Received abort request from HTTP for request ID: {raw_request.get('rid')}"
                )
                self.scheduler.cancel_request(raw_request.get("rid"))
            else:
                # Normal request processing - do tokenization and form InitialRequest
                req = self._handle_raw_request(raw_request)
                recv_reqs.append(req)
        except zmq.ZMQError:
            # Raised by the non-blocking receive; stop draining for this call.
            break
        except Exception as e:
            # Keep the loop alive for the remaining messages and tell the
            # HTTP side that this particular request failed.
            logger.exception(f"Error receiving http request: {e}")
            self._notify_http_request_error(raw_request, e)
    if recv_reqs:
        logger.debug(f"Received {len(recv_reqs)} HTTP requests")
    return recv_reqs

def recv_requests_from_peer(self) -> List[Request]:
Expand Down Expand Up @@ -775,6 +779,34 @@ def _handle_raw_request(self, raw_request: Dict):
req.routing_table = raw_request["routing_table"]
return req

def _notify_http_request_error(self, raw_request: Optional[Dict], error: Exception):
"""Best-effort notification to HTTP server when request parsing fails."""
if not hasattr(self, "send_to_ipc_socket") or self.send_to_ipc_socket is None:
return
if not isinstance(raw_request, dict):
return
rid = raw_request.get("rid")
if rid is None:
return

is_template_error = isinstance(error, TemplateError)
status = (
HTTPStatus.BAD_REQUEST
if isinstance(error, ValueError) or is_template_error
else HTTPStatus.INTERNAL_SERVER_ERROR
)
payload = {
"type": "error",
"rid": rid,
"error": str(error),
"error_type": error.__class__.__name__,
"status_code": status.value,
}
try:
self.send_to_ipc_socket.send_pyobj(payload)
except Exception: # pragma: no cover - best effort notification
logger.debug("Failed to send error notification to HTTP handler", exc_info=True)

def _handle_cuda_input_requests(self, requests: List[Request]):
"""
Cuda specialized handle function.
Expand Down
Loading