Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions docker/Dockerfile.blackwell
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM lmsysorg/sglang:v0.5.3rc1
FROM lmsysorg/sglang:v0.5.4.post1

ENV SGL_ENABLE_JIT_DEEPGEMM=0
ENV SGLANG_ENABLE_JIT_DEEPGEMM=0

WORKDIR /parallax

Expand All @@ -9,4 +9,3 @@ COPY src ./src
COPY pyproject.toml ./pyproject.toml

RUN pip install -e '.[gpu]'
RUN pip install https://github.com/sgl-project/whl/releases/download/v0.3.7/sgl_kernel-0.3.7+cu128-cp310-abi3-manylinux2014_x86_64.whl --force-reinstall
4 changes: 2 additions & 2 deletions docker/Dockerfile.hopper
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
FROM lmsysorg/sglang:v0.5.3rc1
FROM lmsysorg/sglang:v0.5.4.post1

ENV SGL_ENABLE_JIT_DEEPGEMM=0
ENV SGLANG_ENABLE_JIT_DEEPGEMM=0

WORKDIR /parallax

Expand Down
3 changes: 3 additions & 0 deletions src/backend/server/rpc_connection_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def get_layer_allocation(self, current_node_id):
),
"start_layer": start_layer,
"end_layer": end_layer,
"tp_size": node.hardware.num_gpus,
}
return {}

Expand Down Expand Up @@ -182,13 +183,15 @@ def build_node(self, node_json: dict):

def build_hardware(self, hardware_json):
node_id = hardware_json.get("node_id")
num_gpus = hardware_json.get("num_gpus")
tflops_fp16 = hardware_json.get("tflops_fp16")
gpu_name = hardware_json.get("gpu_name")
memory_gb = hardware_json.get("memory_gb")
memory_bandwidth_gbps = hardware_json.get("memory_bandwidth_gbps")
device = hardware_json.get("device")
return NodeHardwareInfo(
node_id=node_id,
num_gpus=num_gpus,
tflops_fp16=tflops_fp16,
gpu_name=gpu_name,
memory_gb=memory_gb,
Expand Down
1 change: 1 addition & 0 deletions src/backend/server/scheduler_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def build_node_info(self, node):
return {
"node_id": node.node_id,
"status": NODE_STATUS_AVAILABLE if node.is_active else NODE_STATUS_WAITING,
"gpu_num": node.hardware.num_gpus,
"gpu_name": node.hardware.gpu_name,
"gpu_memory": node.hardware.memory_gb,
}
Expand Down
2 changes: 1 addition & 1 deletion src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def join_command(args, passthrough_args: list[str] | None = None):

# Set environment variable for the subprocess
env = os.environ.copy()
env["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
env["SGLANG_ENABLE_JIT_DEEPGEMM"] = "0"

# Build the command to run the launch.py script
passthrough_args = passthrough_args or []
Expand Down
56 changes: 43 additions & 13 deletions src/parallax/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
Launch the Parallax server.

This script is used to launch the Parallax server.
It will start the P2P server and the executor.
It will start the following services:
1. Executor with tp_rank=0 in the main process.
2. Executors with tp_rank>0, one subprocess per tp_rank.
3. HTTP server as a subprocess.
4. P2P server as a thread in the main process.

Example command:
python src/parallax/launch.py \
Expand Down Expand Up @@ -41,7 +45,7 @@
gradient_server = None
http_server_process = None
executor = None
executor_procs = []
executor_subprocs = []
try:
args = parse_args()
set_log_level(args.log_level)
Expand Down Expand Up @@ -75,6 +79,7 @@
pp_start_layer=args.start_layer,
pp_end_layer=args.end_layer,
hidden_layers=config.get("num_hidden_layers"),
tp_size=args.tp_size,
tcp_port=args.tcp_port,
udp_port=args.udp_port,
dht_prefix=args.dht_prefix,
Expand All @@ -91,18 +96,21 @@
)
if gradient_server is not None:
gradient_server.status = ServerState.READY
tp_rank_range = range(args.tp_size)
for tp_rank in tp_rank_range:

# For each tp_rank > 0, create a subprocess and run executor
for tp_rank in range(1, args.tp_size):
args_copy = argparse.Namespace(**vars(args))
args_copy.tp_rank = tp_rank
proc = multiprocessing.Process(
target=run_executor_process,
args=(args_copy,),
)
proc.start()
executor_procs.append(proc)
for executor_process in executor_procs:
executor_process.join()
executor_subprocs.append(proc)
# Launch executor with tp_rank=0 in the main process
args.tp_rank = 0
executor = Executor.create_from_args(args)
executor.run_loop()
else:
gradient_server = launch_p2p_server(
initial_peers=args.initial_peers,
Expand All @@ -111,6 +119,7 @@
pp_start_layer=args.start_layer,
pp_end_layer=args.end_layer,
hidden_layers=None,
tp_size=args.tp_size,
tcp_port=args.tcp_port,
udp_port=args.udp_port,
dht_prefix=args.dht_prefix,
Expand All @@ -128,9 +137,7 @@
args.start_layer = gradient_server.block_start_index
args.end_layer = gradient_server.block_end_index
args.model_path = gradient_server.model_name
# TODO: Implement inter-process communication to enable TP.
# For scheduler mode, currently only support tp_rank=0
args.tp_rank = 0
args.tp_size = gradient_server.tp_size

logger.debug(
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
Expand All @@ -148,6 +155,18 @@
# Main execution loop with layer reallocation support
while True:
try:
# For each tp_rank > 0, create a subprocess and run executor
for tp_rank in range(1, args.tp_size):
args_copy = argparse.Namespace(**vars(args))
args_copy.tp_rank = tp_rank
proc = multiprocessing.Process(
target=run_executor_process,
args=(args_copy,),
)
proc.start()
executor_subprocs.append(proc)
# Launch executor with tp_rank=0 in the main process
args.tp_rank = 0
executor = Executor.create_from_args(args, gradient_server=gradient_server)
if gradient_server is not None:
gradient_server.status = ServerState.READY
Expand All @@ -159,7 +178,18 @@
logger.warning(
"Layer allocation changed! Reloading executor with new layers..."
)

# shutdown all executor processes
thread_pool = []
for executor_process in executor_subprocs:
t = threading.Thread(
target=stop_executor_process, args=(executor_process,)
)
t.start()
thread_pool.append(t)
executor.shutdown()
for t in thread_pool:
t.join()

if args.start_layer == 0:
http_server_process = stop_http_server(http_server_process)
Expand Down Expand Up @@ -210,13 +240,13 @@
if gradient_server is not None:
gradient_server.shutdown()

# Shutdown executor subprocess for scheduler mode
for executor_process in executor_procs:
# Shutdown executor subprocesses
for executor_process in executor_subprocs:
t = threading.Thread(target=stop_executor_process, args=(executor_process,))
t.start()
thread_pool.append(t)

# Shutdown executor main process for non-scheduler mode
# Shutdown executor main process
if executor is not None:
executor.shutdown()

Expand Down
5 changes: 5 additions & 0 deletions src/parallax/p2p/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
block_start_index: int = 0,
block_end_index: int = 1,
hidden_layers: int = 128,
tp_size: int = 1,
dht_prefix: str = "gradient",
host_maddrs: List[str] = [],
http_port: Optional[int] = None,
Expand All @@ -220,6 +221,7 @@ def __init__(
self.block_start_index = block_start_index
self.block_end_index = block_end_index
self.hidden_layers = hidden_layers
self.tp_size = tp_size
self.dht_prefix = dht_prefix
self.host_maddrs = host_maddrs
self.announce_maddrs = announce_maddrs
Expand Down Expand Up @@ -346,6 +348,7 @@ def run(self):
self.block_start_index = response.get("start_layer")
self.block_end_index = response.get("end_layer")
self.model_name = response.get("model_name")
self.tp_size = response.get("tp_size")

# Publish executor metrics to backend on each update
def _publish_metrics(_snapshot):
Expand Down Expand Up @@ -738,6 +741,7 @@ def launch_p2p_server(
pp_start_layer: int,
pp_end_layer: int,
hidden_layers: int,
tp_size: int,
tcp_port: int,
udp_port: int,
dht_prefix: str,
Expand All @@ -761,6 +765,7 @@ def launch_p2p_server(
block_start_index=pp_start_layer,
block_end_index=pp_end_layer,
hidden_layers=hidden_layers,
tp_size=tp_size,
dht_prefix=dht_prefix,
host_maddrs=[f"/ip4/0.0.0.0/tcp/{tcp_port}", f"/ip4/0.0.0.0/udp/{udp_port}/quic-v1"],
announce_maddrs=announce_maddrs,
Expand Down
9 changes: 8 additions & 1 deletion src/parallax/server/server_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class HardwareInfo:
total_ram_gb: float
chip: str
tflops_fp16: float
num_gpus: int

def dumps(self) -> Dict[str, Any]:
"""Serializes the HardwareInfo object to a dictionary."""
Expand Down Expand Up @@ -99,7 +100,7 @@ def detect(cls) -> "AppleSiliconHardwareInfo":
"Please add it to the _APPLE_PEAK_FP16 dictionary."
) from e

return cls(total_ram_gb=round(total_gb, 1), chip=chip, tflops_fp16=flops)
return cls(num_gpus=1, total_ram_gb=round(total_gb, 1), chip=chip, tflops_fp16=flops)


@dataclass
Expand Down Expand Up @@ -143,6 +144,7 @@ def detect(cls) -> "NvidiaHardwareInfo":
if torch is None or not torch.cuda.is_available():
raise RuntimeError("CUDA not available; cannot detect NVIDIA hardware")

device_count = torch.cuda.device_count()
device_index = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device_index)
name = getattr(props, "name", f"cuda:{device_index}")
Expand All @@ -156,6 +158,7 @@ def detect(cls) -> "NvidiaHardwareInfo":

spec = cls._match_gpu_specs(name, total_vram_gb)
return cls(
num_gpus=device_count,
total_ram_gb=round(total_gb, 1),
chip=name,
tflops_fp16=float(spec["tflops_fp16"]),
Expand All @@ -179,6 +182,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
# Fallback to a conservative default
return {
"node_id": node_id,
"num_gpus": 1,
"tflops_fp16": 50.0,
"gpu_name": "Unknown",
"memory_gb": 16.0,
Expand All @@ -189,6 +193,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
if isinstance(hw, NvidiaHardwareInfo):
return {
"node_id": node_id,
"num_gpus": hw.num_gpus,
"tflops_fp16": hw.tflops_fp16,
"gpu_name": hw.chip,
"memory_gb": hw.vram_gb,
Expand All @@ -200,6 +205,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
est_bandwidth = 100.0
return {
"node_id": node_id,
"num_gpus": hw.num_gpus,
"tflops_fp16": hw.tflops_fp16,
"gpu_name": hw.chip,
"memory_gb": hw.total_ram_gb,
Expand All @@ -209,6 +215,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
# Generic fallback
return {
"node_id": node_id,
"num_gpus": hw.num_gpus,
"tflops_fp16": hw.tflops_fp16,
"gpu_name": "Unknown",
"memory_gb": 16.0,
Expand Down
4 changes: 2 additions & 2 deletions src/scheduling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,13 @@ model = ModelInfo( # instantiate with your model's parameters

n0 = Node(
node_id="node-0",
hardware=NodeHardwareInfo(node_id="node-0", tflops_fp16=180.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
hardware=NodeHardwareInfo(node_id="node-0", tflops_fp16=180.0, num_gpus=1, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
model_info=model,
)

n1 = Node(
node_id="node-1",
hardware=NodeHardwareInfo(node_id="node-1", tflops_fp16=180.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
hardware=NodeHardwareInfo(node_id="node-1", tflops_fp16=180.0, num_gpus=1, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039.0),
model_info=model,
)

Expand Down
4 changes: 3 additions & 1 deletion src/scheduling/layer_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,9 @@ def should_global_rebalance(self) -> bool:
if len(layer_heap) < 2:
return False

total_cluster_memory = sum(node.hardware.memory_gb for node in self.nodes)
total_cluster_memory = sum(
(node.hardware.num_gpus * node.hardware.memory_gb) for node in self.nodes
)

if total_cluster_memory == 0:
raise ValueError("Total cluster memory is zero")
Expand Down
17 changes: 15 additions & 2 deletions src/scheduling/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class NodeHardwareInfo:
"""

node_id: str
num_gpus: int
tflops_fp16: float
gpu_name: str
memory_gb: float
Expand Down Expand Up @@ -272,7 +273,12 @@ def get_decoder_layer_capacity(
Capacity is measured using the parameter memory budget on the device.
"""
available_memory_bytes = floor(
self.hardware.memory_gb * 1024 * 1024 * 1024 * self.param_hosting_ratio
self.hardware.num_gpus
* self.hardware.memory_gb
* 1024
* 1024
* 1024
* self.param_hosting_ratio
)
if include_input_embed:
available_memory_bytes -= self.model_info.embedding_io_bytes
Expand Down Expand Up @@ -300,7 +306,14 @@ def per_decoder_layer_kv_cache_memory(self) -> Optional[int]:
if self.num_current_layers == 0:
return None
return floor(
(self.hardware.memory_gb * 1024 * 1024 * 1024 * self.kv_cache_ratio)
(
self.hardware.num_gpus
* self.hardware.memory_gb
* 1024
* 1024
* 1024
* self.kv_cache_ratio
)
/ self.num_current_layers
)

Expand Down
8 changes: 4 additions & 4 deletions tests/scheduler_tests/test_layer_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@

def _build_node(gpu_type: str, model: ModelInfo, id_suffix: str = "") -> Node:
hw_map = {
"a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 312.0, "", 80.0, 2039.0, "cuda"),
"a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 312.0, "", 40.0, 1935.0, "cuda"),
"rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 165, "", 32.0, 1792.0, "cuda"),
"rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 82.6, "", 24.0, 1008.0, "cuda"),
"a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 1, 312.0, "", 80.0, 2039.0, "cuda"),
"a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 1, 312.0, "", 40.0, 1935.0, "cuda"),
"rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 1, 165, "", 32.0, 1792.0, "cuda"),
"rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 1, 82.6, "", 24.0, 1008.0, "cuda"),
}
hw = hw_map[gpu_type]
return Node(node_id=hw.node_id, hardware=hw, model_info=model)
Expand Down
Loading