Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ An example of serving Qwen3-0.6B with 2-nodes:
python3 ./parallax/src/parallax/launch.py \
--model-path Qwen/Qwen3-0.6B \
--port 3000 \
--dht-port 5000 \
--max-batch-size 8 \
--start-layer 0 \
--end-layer 14
Expand All @@ -218,7 +217,6 @@ python3 ./parallax/src/parallax/launch.py \
python3 ./parallax/src/parallax/launch.py \
--model-path Qwen/Qwen3-0.6B \
--port 3000 \
--dht-port 5000 \
--max-batch-size 8 \
--start-layer 14 \
--end-layer 28
Expand Down
2 changes: 1 addition & 1 deletion src/backend/benchmark/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Adapted from vLLM: https://github.com/vllm-project/vllm/blob/v0.7.2/benchmarks/benchmark_serving.py

On the server side (parallax scheduler with oAI API server), run
python src/backend/main.py --dht-port 31328 --port 31328 --init-nodes-num 1
python src/backend/main.py --port 31328 --init-nodes-num 1

On the worker side (parallax worker nodes),
1. Get `scheduler-addr` from the scheduler's launch output
Expand Down
13 changes: 4 additions & 9 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,20 +130,15 @@ async def serve_index():
logger.info(f"args: {args}")
if args.log_level != "DEBUG":
display_parallax_run()
host_maddrs = args.host_maddrs
dht_port = args.dht_port
if args.dht_port is not None:
assert host_maddrs is None, "You can't use --dht-port and --host-maddrs at the same time"
else:
dht_port = 0
if host_maddrs is None:
host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}", f"/ip6/::/tcp/{dht_port}"]

scheduler_manage = SchedulerManage(
initial_peers=args.initial_peers,
relay_servers=args.relay_servers,
dht_prefix=args.dht_prefix,
host_maddrs=host_maddrs,
host_maddrs=[
f"/ip4/0.0.0.0/tcp/{args.tcp_port}",
f"/ip4/0.0.0.0/udp/{args.udp_port}/quic-v1",
],
announce_maddrs=args.announce_maddrs,
)

Expand Down
18 changes: 4 additions & 14 deletions src/backend/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,37 +10,27 @@ def parse_args() -> argparse.Namespace:
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

# P2P configuration
# Lattica configuration
parser.add_argument("--initial-peers", nargs="+", default=[], help="List of initial DHT peers")

parser.add_argument("--relay-servers", nargs="+", default=[], help="List of relay DHT peers")

parser.add_argument(
"--announce-maddrs", nargs="+", default=[], help="List of multiaddresses to announce"
)

parser.add_argument("--dht-port", type=int, default=None, help="Port for DHT communication")

parser.add_argument("--host-maddrs", type=str, default=None, help="Multiaddress to host")

parser.add_argument("--tcp-port", type=int, default=0, help="Port for Lattica TCP listening")
parser.add_argument("--udp-port", type=int, default=0, help="Port for Lattica UDP listening")
parser.add_argument("--dht-prefix", type=str, default="gradient", help="Prefix for DHT keys")

parser.add_argument("--public-ip", type=str, default=None, help="Public IP address to announce")

# Scheduler configuration
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")

parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Log level",
)

parser.add_argument("--model-name", type=str, default=None, help="Model name")

parser.add_argument("--init-nodes-num", type=int, default=None, help="Number of initial nodes")

parser.add_argument(
"--is-local-network", type=bool, default=True, help="Whether to use local network"
)
Expand Down
2 changes: 0 additions & 2 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,6 @@ def run_command(args, passthrough_args: list[str] | None = None):
# Build the command to run the backend main.py
passthrough_args = passthrough_args or []
cmd = [sys.executable, str(backend_main)]
if not _flag_present(passthrough_args, ["--dht-port"]):
cmd.extend(["--dht-port", "5001"])
if not _flag_present(passthrough_args, ["--port"]):
cmd.extend(["--port", "3001"])

Expand Down
15 changes: 7 additions & 8 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@

Example command:
python src/parallax/launch.py \
--model-path Qwen/Qwen3-0.6B-MLX-bf16 \
--model-path Qwen/Qwen3-0.6B \
--max-num-tokens-per-batch 16384 \
--max-batch-size 128 \
--start-layer 14 \
--end-layer 28 \
--initial-peers {peer of the GPU which holds the first half of the model}
--start-layer 0 \
--end-layer 28
"""

import multiprocessing
Expand Down Expand Up @@ -82,9 +81,9 @@
pp_start_layer=args.start_layer,
pp_end_layer=args.end_layer,
hidden_layers=executor.config.get("num_hidden_layers"),
dht_port=args.dht_port,
tcp_port=args.tcp_port,
udp_port=args.udp_port,
dht_prefix=args.dht_prefix,
host_maddrs=args.host_maddrs,
announce_maddrs=args.announce_maddrs,
http_port=args.port,
notify_url=args.notify_url,
Expand All @@ -102,9 +101,9 @@
pp_start_layer=None,
pp_end_layer=None,
hidden_layers=None,
dht_port=args.dht_port,
tcp_port=args.tcp_port,
udp_port=args.udp_port,
dht_prefix=args.dht_prefix,
host_maddrs=args.host_maddrs,
announce_maddrs=args.announce_maddrs,
http_port=args.port,
notify_url=args.notify_url,
Expand Down
17 changes: 4 additions & 13 deletions src/parallax/p2p/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,8 +354,7 @@ def find_servers(self):
"""Find available servers in the DHT network"""
# Find all announced blocks
server_blocks = []
block_announced_key = f"{self.dht_prefix}_announce"
block_servers = self.lattica.get(block_announced_key)
block_servers = self.lattica.get(self.prefix_id)
if block_servers is None:
return []
for peer_id, value in block_servers.value.items():
Expand Down Expand Up @@ -652,9 +651,9 @@ def launch_p2p_server(
pp_start_layer: int,
pp_end_layer: int,
hidden_layers: int,
dht_port: Optional[int],
tcp_port: int,
udp_port: int,
dht_prefix: str,
host_maddrs: Optional[List[str]],
announce_maddrs: List[str],
http_port: Optional[int],
notify_url: str,
Expand All @@ -664,14 +663,6 @@ def launch_p2p_server(
max_batch_size: Optional[int] = None,
max_sequence_length: Optional[int] = None,
):
if dht_port is not None:
assert host_maddrs is None, "You can't use --dht-port and --host-maddrs at the same time"
else:
dht_port = 0
if host_maddrs is None:
host_maddrs = [f"/ip4/0.0.0.0/tcp/{dht_port}", f"/ip4/0.0.0.0/udp/{dht_port}/quic-v1"]

# Run the server in a separate thread to keep the main thread free for event loop
server = GradientServer(
recv_from_peer_addr=recv_from_peer_addr,
send_to_peer_addr=send_to_peer_addr,
Expand All @@ -682,7 +673,7 @@ def launch_p2p_server(
block_end_index=pp_end_layer,
hidden_layers=hidden_layers,
dht_prefix=dht_prefix,
host_maddrs=host_maddrs,
host_maddrs=[f"/ip4/0.0.0.0/tcp/{tcp_port}", f"/ip4/0.0.0.0/udp/{udp_port}/quic-v1"],
announce_maddrs=announce_maddrs,
http_port=http_port,
notify_url=notify_url,
Expand Down
13 changes: 3 additions & 10 deletions src/parallax/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,16 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--host", type=str, default="localhost", help="Host of the HTTP server.")
parser.add_argument("--port", type=int, default=3000, help="Port of the HTTP server")

# P2P configuration
# Lattica configuration
parser.add_argument("--initial-peers", nargs="+", default=[], help="List of initial DHT peers")

parser.add_argument("--scheduler-addr", type=str, default=None, help="Scheduler address")

parser.add_argument("--relay-servers", nargs="+", default=[], help="List of relay DHT peers")

parser.add_argument("--dht-port", type=int, default=None, help="Port for DHT communication")

parser.add_argument("--host-maddrs", type=str, default=None, help="Multiaddress to host")

parser.add_argument("--tcp-port", type=int, default=0, help="Port for Lattica TCP listening")
parser.add_argument("--udp-port", type=int, default=0, help="Port for Lattica UDP listening")
parser.add_argument(
"--announce-maddrs", nargs="+", default=[], help="List of multiaddresses to announce"
)

parser.add_argument("--dht-prefix", type=str, default="gradient", help="Prefix for DHT keys")

parser.add_argument(
"--notify-url", type=str, default=None, help="URL to notify when a request is finished"
)
Expand Down