diff --git a/src/backend/main.py b/src/backend/main.py index c72b0a6e..bfd7cb5a 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -95,14 +95,6 @@ async def stream_cluster_status(): ) -@app.post("/v1/completions") -async def openai_v1_completions(raw_request: Request): - request_data = await raw_request.json() - request_id = uuid.uuid4() - received_ts = time.time() - return await request_handler.v1_completions(request_data, request_id, received_ts) - - @app.post("/v1/chat/completions") async def openai_v1_chat_completions(raw_request: Request): request_data = await raw_request.json() @@ -145,6 +137,7 @@ async def serve_index(): f"/ip4/0.0.0.0/udp/{args.udp_port}/quic-v1", ], announce_maddrs=args.announce_maddrs, + http_port=args.port, ) request_handler.set_scheduler_manage(scheduler_manage) diff --git a/src/backend/server/constants.py b/src/backend/server/constants.py index fd4c5f22..74cded7a 100644 --- a/src/backend/server/constants.py +++ b/src/backend/server/constants.py @@ -2,6 +2,7 @@ CLUSTER_STATUS_WAITING = "waiting" CLUSTER_STATUS_AVAILABLE = "available" CLUSTER_STATUS_REBALANCING = "rebalancing" +CLUSTER_STATUS_FAILED = "failed" # Node status constants NODE_STATUS_WAITING = "waiting" diff --git a/src/backend/server/rpc_connection_handler.py b/src/backend/server/rpc_connection_handler.py index 50185d86..0fee9922 100644 --- a/src/backend/server/rpc_connection_handler.py +++ b/src/backend/server/rpc_connection_handler.py @@ -1,6 +1,6 @@ import time -from lattica import ConnectionHandler, Lattica, rpc_method, rpc_stream +from lattica import ConnectionHandler, Lattica, rpc_method, rpc_stream, rpc_stream_iter from parallax_utils.logging_config import get_logger from scheduling.node import Node, NodeHardwareInfo @@ -8,6 +8,10 @@ logger = get_logger(__name__) +import json + +import httpx + class RPCConnectionHandler(ConnectionHandler): """ @@ -19,10 +23,12 @@ def __init__( self, lattica: Lattica, scheduler: Scheduler, + http_port: int, ): # Initialize the base class super().__init__(lattica) self.scheduler = scheduler + self.http_port = http_port @rpc_stream def node_join(self, message): @@ -79,6 +85,47 @@ def node_update(self, message): logger.exception(f"node_update error: {e}") return {} + @rpc_stream_iter + def chat_completion( + self, + request, + ): + """Handle chat completion request""" + logger.debug(f"Chat completion request: {request}, type: {type(request)}") + try: + with httpx.Client(timeout=10 * 60, proxy=None, trust_env=False) as client: + if request.get("stream", False): + with client.stream( + "POST", + f"http://localhost:{self.http_port}/v1/chat/completions", + json=request, + ) as response: + for chunk in response.iter_bytes(): + if chunk: + yield chunk + else: + response = client.post( + f"http://localhost:{self.http_port}/v1/chat/completions", json=request + ).json() + yield json.dumps(response).encode() + except Exception as e: + logger.exception(f"Error in chat completion: {e}") + yield b"internal server error" + + @rpc_stream_iter + def cluster_status(self): + try: + with httpx.Client(timeout=10 * 60, proxy=None, trust_env=False) as client: + with client.stream( + "GET", f"http://localhost:{self.http_port}/cluster/status" + ) as response: + for chunk in response.iter_bytes(): + if chunk: + yield chunk + except Exception as e: + logger.exception(f"Error in cluster status: {e}") + yield json.dumps({"error": "internal server error"}).encode() + def wait_layer_allocation(self, current_node_id, wait_seconds): start_time = time.time() while True: diff --git a/src/backend/server/scheduler_manage.py b/src/backend/server/scheduler_manage.py index 582a9dc4..bdb78982 100644 --- a/src/backend/server/scheduler_manage.py +++ b/src/backend/server/scheduler_manage.py @@ -31,6 +31,7 @@ def __init__( dht_prefix: str = "gradient", host_maddrs: List[str] = [], announce_maddrs: List[str] = [], + http_port: int = 3001, ): """Initialize the manager with networking bootstrap parameters.""" self.initial_peers = initial_peers @@ -38,7 +39,7 @@ def __init__( self.dht_prefix = dht_prefix self.host_maddrs = host_maddrs self.announce_maddrs = announce_maddrs - + self.http_port = http_port self.model_name = None self.init_nodes_num = None self.scheduler = None @@ -190,6 +191,7 @@ def _start_lattica(self): self.connection_handler = RPCConnectionHandler( lattica=self.lattica, scheduler=self.scheduler, + http_port=self.http_port, ) logger.debug("RPCConnectionHandler initialized") diff --git a/src/frontend/chat.html b/src/frontend/chat.html new file mode 100644 index 00000000..0ad191f8 --- /dev/null +++ b/src/frontend/chat.html @@ -0,0 +1,13 @@ + + +
+ + + +