Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"protobuf==6.31.1",
"dijkstar==2.6.0",
"huggingface-hub",
"lattica==1.0.2",
"lattica==1.0.3",
]

[project.scripts]
Expand Down
69 changes: 19 additions & 50 deletions src/backend/server/request_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Dict

import aiohttp
from fastapi import HTTPException
from fastapi.responses import JSONResponse, StreamingResponse

from backend.server.constants import NODE_STATUS_AVAILABLE
Expand All @@ -26,16 +26,18 @@ class RequestHandler:

def __init__(self):
self.scheduler_manage = None
self.stubs = {}

def set_scheduler_manage(self, scheduler_manage):
self.scheduler_manage = scheduler_manage

async def _forward_request(
self, endpoint: str, request_data: Dict, request_id: str, received_ts: int
):
logger.debug(
f"Forwarding request {request_id} to endpoint {endpoint}; stream={request_data.get('stream', False)}"
)
def get_stub(self, node_id):
if node_id not in self.stubs:
self.stubs[node_id] = self.scheduler_manage.completion_handler.get_stub(node_id)
return self.stubs[node_id]

async def _forward_request(self, request_data: Dict, request_id: str, received_ts: int):
logger.debug(f"Forwarding request {request_id}; stream={request_data.get('stream', False)}")
if (
self.scheduler_manage is None
or not self.scheduler_manage.get_schedule_status() == NODE_STATUS_AVAILABLE
Expand Down Expand Up @@ -88,41 +90,14 @@ async def _forward_request(
)

request_data["routing_table"] = routing_table
call_url = self.scheduler_manage.get_call_url_by_node_id(routing_table[0])
logger.debug(
f"Resolved call_url for request {request_id}: node={routing_table[0]} -> {call_url}"
)

if not call_url:
return JSONResponse(
content={"error": "Call url not found of peer id: " + routing_table[0]},
status_code=500,
)

url = call_url + endpoint
stub = self.get_stub(routing_table[0])
is_stream = request_data.get("stream", False)
logger.debug(f"POST upstream: url={url}, stream={is_stream}")

async def _process_upstream_response(response: aiohttp.ClientResponse):
logger.debug(f"post: {request_id}, code: {response.status}, params: {request_data}")
if response.status != 200:
error_text = await response.text()
error_msg = (
f"Upstream service returned status {response.status}, response: {error_text}"
)
logger.error(f"completions error: {error_msg}, request_id: {request_id}")
raise HTTPException(status_code=response.status, detail=error_msg)

if is_stream:

async def stream_generator():
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with session.post(url, json=request_data) as response:
await _process_upstream_response(response)

async for chunk in response.content:
if chunk:
yield chunk
def stream_generator():
for chunk in stub.chat_completion(request_data):
yield chunk

resp = StreamingResponse(
stream_generator(),
Expand All @@ -135,17 +110,11 @@ async def stream_generator():
logger.debug(f"Streaming response initiated for {request_id}")
return resp
else:
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
async with session.post(url, json=request_data) as response:
await _process_upstream_response(response)
result = await response.json()
logger.debug(f"Non-stream response completed for {request_id}")
return JSONResponse(content=result)

async def v1_completions(self, request_data: Dict, request_id: str, received_ts: int):
return await self._forward_request("/v1/completions", request_data, request_id, received_ts)
response = stub.chat_completion(request_data)
response = next(response).decode()
logger.debug(f"Non-stream response completed for {request_id}")
# response is a JSON string; parse to Python object before returning
return JSONResponse(content=json.loads(response))

async def v1_chat_completions(self, request_data: Dict, request_id: str, received_ts: int):
return await self._forward_request(
"/v1/chat/completions", request_data, request_id, received_ts
)
return await self._forward_request(request_data, request_id, received_ts)
14 changes: 0 additions & 14 deletions src/backend/server/rpc_connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,10 @@ def __init__(
# Initialize the base class
super().__init__(lattica)
self.scheduler = scheduler
self.call_url_map = {}

@rpc_stream
def node_join(self, message):
# node = {
# "http_port": "8000",
# "node_id": "lattica peer id",
# "hardware": {
# "node_id": "lattica peer id",
Expand All @@ -44,14 +42,6 @@ def node_join(self, message):
logger.info(f"receive node_join request: {message}")
try:
node = self.build_node(message)

try:
node_ip = self.lattica_instance.get_peer_addresses(node.node_id)[0].split("/")[2]
logger.info(f"get ip for {node.node_id}: {node_ip}")
except Exception as e:
logger.warning(f"Failed to get ip for {node.node_id}: {e}, using 127.0.0.1")
node_ip = "127.0.0.1"
self.call_url_map[node.node_id] = f"http://{node_ip}:{message.get('http_port')}"
self.scheduler.enqueue_join(node)

response = self.wait_layer_allocation(node.node_id, wait_seconds=300)
Expand All @@ -67,7 +57,6 @@ def node_leave(self, message):
try:
node = self.build_node(message)
self.scheduler.enqueue_leave(node.node_id)
self.call_url_map.pop(node.node_id)
return {}
except Exception as e:
logger.exception(f"node_leave error: {e}")
Expand Down Expand Up @@ -148,6 +137,3 @@ def build_hardware(self, hardware_json):
memory_gb=memory_gb,
memory_bandwidth_gbps=memory_bandwidth_gbps,
)

def get_call_url_by_node_id(self, node_id):
return self.call_url_map.get(node_id, None)
8 changes: 8 additions & 0 deletions src/backend/server/scheduler_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_model_info,
get_node_join_command,
)
from parallax.p2p.server import TransformerConnectionHandler
from parallax_utils.logging_config import get_logger
from scheduling.node import RequestSignal
from scheduling.scheduler import Scheduler
Expand Down Expand Up @@ -63,6 +64,13 @@ def run(self, model_name, init_nodes_num, is_local_network=True):

self._start_scheduler(model_name, init_nodes_num)
self._start_lattica()
self.completion_handler = TransformerConnectionHandler(
lattica=self.lattica,
recv_from_peer_addr="",
send_to_peer_addr="",
block_start_index=0,
block_end_index=1,
)

def is_running(self):
"""
Expand Down
32 changes: 14 additions & 18 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ def get_project_root():
return Path.cwd()


def get_relay_params():
return [
"--relay-servers",
"/dns4/relay-lattica.gradient.network/udp/18080/quic-v1/p2p/12D3KooWDaqDAsFupYvffBDxjHHuWmEAJE4sMDCXiuZiB8aG8rjf",
"/dns4/relay-lattica.gradient.network/tcp/18080/p2p/12D3KooWDaqDAsFupYvffBDxjHHuWmEAJE4sMDCXiuZiB8aG8rjf",
"--initial-peers",
"/dns4/bootstrap-lattica.gradient.network/udp/18080/quic-v1/p2p/12D3KooWJHXvu8TWkFn6hmSwaxdCLy4ZzFwr4u5mvF9Fe2rMmFXb",
"/dns4/bootstrap-lattica.gradient.network/tcp/18080/p2p/12D3KooWJHXvu8TWkFn6hmSwaxdCLy4ZzFwr4u5mvF9Fe2rMmFXb",
]


def run_command(args):
"""Run the scheduler (equivalent to scripts/start.sh)."""
check_python_version()
Expand Down Expand Up @@ -68,12 +79,7 @@ def run_command(args):
if args.init_nodes_num:
cmd.extend(["--init-nodes-num", str(args.init_nodes_num)])
if args.use_relay:
cmd.extend(
[
"--relay-servers",
"/dns4/relay-lattica.gradient.network/udp/18080/quic-v1/p2p/12D3KooWDaqDAsFupYvffBDxjHHuWmEAJE4sMDCXiuZiB8aG8rjf /dns4/relay-lattica.gradient.network/tcp/18080/p2p/12D3KooWDaqDAsFupYvffBDxjHHuWmEAJE4sMDCXiuZiB8aG8rjf",
]
)
cmd.extend(get_relay_params())

logger.info(f"Running command: {' '.join(cmd)}")

Expand Down Expand Up @@ -146,18 +152,8 @@ def join_command(args):
if args.use_relay or (
args.scheduler_addr != "auto" and not str(args.scheduler_addr).startswith("/")
):
cmd.extend(
[
"--relay-servers",
"/dns4/relay-lattica.gradient.network/udp/18080/quic-v1/p2p/12D3KooWDaqDAsFupYvffBDxjHHuWmEAJE4sMDCXiuZiB8aG8rjf /dns4/relay-lattica.gradient.network/tcp/18080/p2p/12D3KooWDaqDAsFupYvffBDxjHHuWmEAJE4sMDCXiuZiB8aG8rjf",
]
)
cmd.extend(
[
"--initial-peers",
"/dns4/bootstrap-lattica.gradient.network/udp/18080/quic-v1/p2p/12D3KooWJHXvu8TWkFn6hmSwaxdCLy4ZzFwr4u5mvF9Fe2rMmFXb /dns4/bootstrap-lattica.gradient.network/tcp/18080/p2p/12D3KooWJHXvu8TWkFn6hmSwaxdCLy4ZzFwr4u5mvF9Fe2rMmFXb",
]
)
logger.info("Using public relay servers")
cmd.extend(get_relay_params())

logger.info(f"Running command: {' '.join(cmd)}")
logger.info(f"Scheduler address: {args.scheduler_addr}")
Expand Down
4 changes: 2 additions & 2 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
dht_prefix=args.dht_prefix,
host_maddrs=args.host_maddrs,
announce_maddrs=args.announce_maddrs,
http_port=args.port if args.announce_http_port is None else args.announce_http_port,
http_port=args.port,
notify_url=args.notify_url,
recv_from_peer_addr=args.recv_from_peer_addr,
send_to_peer_addr=args.send_to_peer_addr,
Expand All @@ -97,7 +97,7 @@
dht_prefix=args.dht_prefix,
host_maddrs=args.host_maddrs,
announce_maddrs=args.announce_maddrs,
http_port=args.port if args.announce_http_port is None else args.announce_http_port,
http_port=args.port,
notify_url=args.notify_url,
recv_from_peer_addr=args.recv_from_peer_addr,
send_to_peer_addr=args.send_to_peer_addr,
Expand Down
36 changes: 33 additions & 3 deletions src/parallax/p2p/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import dataclasses
import enum
import json
import logging
import threading
import time
Expand All @@ -17,7 +18,7 @@
import dijkstar
import httpx
import zmq
from lattica import ConnectionHandler, Lattica, rpc_method, rpc_stream
from lattica import ConnectionHandler, Lattica, rpc_method, rpc_stream, rpc_stream_iter

from backend.server.rpc_connection_handler import RPCConnectionHandler
from parallax.p2p.proto import forward_pb2
Expand Down Expand Up @@ -105,6 +106,7 @@ def __init__(
send_to_peer_addr: str,
block_start_index: int,
block_end_index: int,
http_port: Optional[int] = None,
notify_url: Optional[str] = None,
):
# Initialize the base class
Expand All @@ -113,6 +115,7 @@ def __init__(
self.send_to_peer_addr = send_to_peer_addr
self.block_start_index = block_start_index
self.block_end_index = block_end_index
self.http_port = http_port
self.notify_url = notify_url
self._recv_from_peer = None
self._recv_from_peer_lock = threading.Lock()
Expand Down Expand Up @@ -153,6 +156,33 @@ def rpc_abort(
logger.exception(f"Error in rpc_abort: {e}")
return forward_pb2.AbortResponse()

@rpc_stream_iter
def chat_completion(
self,
request,
):
"""Handle chat completion request"""
logger.debug(f"Chat completion request: {request}, type: {type(request)}")
try:
if request.get("stream", False):
with httpx.Client(timeout=20 * 60 * 60) as client:
with client.stream(
"POST",
f"http://localhost:{self.http_port}/v1/chat/completions",
json=request,
) as response:
for chunk in response.iter_bytes():
if chunk:
yield chunk
else:
with httpx.Client(timeout=20 * 60 * 60) as client:
response = client.post(
f"http://localhost:{self.http_port}/v1/chat/completions", json=request
).json()
yield json.dumps(response).encode()
except Exception as e:
logger.exception(f"Error in chat completion: {e}")


class GradientServer:
"""
Expand Down Expand Up @@ -313,6 +343,7 @@ def _publish_metrics(_snapshot):
send_to_peer_addr=self.send_to_peer_addr,
block_start_index=self.block_start_index,
block_end_index=self.block_end_index,
http_port=self.http_port,
notify_url=self.notify_url,
) # thread

Expand Down Expand Up @@ -546,7 +577,7 @@ def get_node_info(self, is_update: bool = False):
if time.time() - self.rtt_last_update > self.rtt_update_interval:
self.rtts = {}
all_peers = []
for _ in range(1 if is_update else 30):
for _ in range(1 if is_update else 10):
all_peers = self.lattica.get_all_peers()
if len(all_peers) > 0 and self.scheduler_peer_id in all_peers:
break
Expand Down Expand Up @@ -577,7 +608,6 @@ def get_node_info(self, is_update: bool = False):
self.rtt_last_update = time.time()

info = {
"http_port": f"{self.http_port}",
"node_id": self.lattica.peer_id(),
"hardware": detect_node_hardware(self.lattica.peer_id()),
"kv_cache_ratio": 0.25,
Expand Down
5 changes: 0 additions & 5 deletions src/parallax/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,8 @@ def parse_args() -> argparse.Namespace:

# HTTP server configuration
parser.add_argument("--host", type=str, default="localhost", help="Host of the HTTP server.")

parser.add_argument("--port", type=int, default=3000, help="Port of the HTTP server")

parser.add_argument(
"--announce-http-port", type=str, default=None, help="HTTP port to announce"
)

# P2P configuration
parser.add_argument("--initial-peers", nargs="+", default=[], help="List of initial DHT peers")

Expand Down