src/backend/server/rpc_connection_handler.py (20 changes: 14 additions & 6 deletions)
@@ -140,12 +140,18 @@ def get_layer_allocation(self, current_node_id):
list_node_allocations = self.scheduler.list_node_allocations()
for node_id, start_layer, end_layer in list_node_allocations:
if current_node_id == node_id:
return {
"node_id": node_id,
"model_name": self.scheduler.model_info.model_name,
"start_layer": start_layer,
"end_layer": end_layer,
}
node = self.scheduler.node_id_to_node.get(node_id)
if node:
return {
"node_id": node_id,
"model_name": (
node.model_info.model_name
if node.hardware.device != "mlx"
else node.model_info.mlx_model_name
),
"start_layer": start_layer,
"end_layer": end_layer,
}
return {}

def build_node(self, node_json: dict):
@@ -177,10 +183,12 @@ def build_hardware(self, hardware_json):
gpu_name = hardware_json.get("gpu_name")
memory_gb = hardware_json.get("memory_gb")
memory_bandwidth_gbps = hardware_json.get("memory_bandwidth_gbps")
device = hardware_json.get("device")
return NodeHardwareInfo(
node_id=node_id,
tflops_fp16=tflops_fp16,
gpu_name=gpu_name,
memory_gb=memory_gb,
memory_bandwidth_gbps=memory_bandwidth_gbps,
device=device,
)
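
Note: the allocation reply above now resolves the model name per node: when the node's `hardware.device` is `"mlx"`, the scheduler hands back `mlx_model_name` instead of the source repo, and nodes missing from `node_id_to_node` fall through to the empty-dict return. A minimal sketch of that selection rule, using stand-in types rather than the project's `Node`/`ModelInfo` classes:

```python
from dataclasses import dataclass

@dataclass
class StubModelInfo:
    model_name: str       # source Hugging Face repo
    mlx_model_name: str   # MLX conversion (may equal model_name)

def resolve_model_name(device: str, info: StubModelInfo) -> str:
    # Same branch as get_layer_allocation: MLX nodes get the MLX repo.
    return info.mlx_model_name if device == "mlx" else info.model_name

info = StubModelInfo("openai/gpt-oss-20b", "mlx-community/gpt-oss-20b-MXFP4-Q8")
assert resolve_model_name("mlx", info) == "mlx-community/gpt-oss-20b-MXFP4-Q8"
assert resolve_model_name("cuda", info) == "openai/gpt-oss-20b"
```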
src/backend/server/static_config.py (119 changes: 65 additions & 54 deletions)
@@ -1,58 +1,60 @@
import json
import logging
from pathlib import Path

from scheduling.model_info import ModelInfo

# Supported model list
MODEL_LIST = [
"Qwen/Qwen3-0.6B",
"openai/gpt-oss-20b",
"openai/gpt-oss-120b",
"moonshotai/Kimi-K2-Instruct",
"moonshotai/Kimi-K2-Instruct-0905",
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8",
"Qwen/Qwen3-Next-80B-A3B-Thinking",
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8",
"Qwen/Qwen3-0.6B-FP8",
"Qwen/Qwen3-1.7B",
"Qwen/Qwen3-1.7B-FP8",
"Qwen/Qwen3-4B",
"Qwen/Qwen3-4B-FP8",
"Qwen/Qwen3-4B-Instruct-2507",
"Qwen/Qwen3-4B-Instruct-2507-FP8",
"Qwen/Qwen3-4B-Thinking-2507",
"Qwen/Qwen3-4B-Thinking-2507-FP8",
"Qwen/Qwen3-8B",
"Qwen/Qwen3-8B-FP8",
"Qwen/Qwen3-14B",
"Qwen/Qwen3-14B-FP8",
"Qwen/Qwen3-32B",
"Qwen/Qwen3-32B-FP8",
"Qwen/Qwen3-30B-A3B",
"Qwen/Qwen3-30B-A3B-Instruct-2507-FP8",
"Qwen/Qwen3-30B-A3B-Thinking-2507-FP8",
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8",
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8",
"Qwen/Qwen3-235B-A22B-GPTQ-Int4",
"Qwen/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct",
"nvidia/Llama-3.3-70B-Instruct-FP8",
"nvidia/Llama-3.1-70B-Instruct-FP8",
"nvidia/Llama-3.1-8B-Instruct-FP8",
"deepseek-ai/DeepSeek-V3.1",
"deepseek-ai/DeepSeek-R1",
"deepseek-ai/DeepSeek-V3",
"deepseek-ai/DeepSeek-V2",
"MiniMaxAI/MiniMax-M2",
"zai-org/GLM-4.6",
]

# Supported model list - key: model name, value: MLX model name (same as key if no MLX variant)
MODELS = {
"Qwen/Qwen3-0.6B": "Qwen/Qwen3-0.6B",
"openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8",
"openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-4bit",
"moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit",
"moonshotai/Kimi-K2-Instruct-0905": "mlx-community/Kimi-K2-Instruct-0905-mlx-DQ3_K_M",
"Qwen/Qwen3-Next-80B-A3B-Instruct": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
"Qwen/Qwen3-0.6B-FP8": "Qwen/Qwen3-0.6B",
"Qwen/Qwen3-1.7B": "Qwen/Qwen3-1.7B",
"Qwen/Qwen3-1.7B-FP8": "Qwen/Qwen3-1.7B",
"Qwen/Qwen3-4B": "Qwen/Qwen3-4B",
"Qwen/Qwen3-4B-FP8": "Qwen/Qwen3-4B",
"Qwen/Qwen3-4B-Instruct-2507": "Qwen/Qwen3-4B-Instruct-2507",
"Qwen/Qwen3-4B-Instruct-2507-FP8": "Qwen/Qwen3-4B-Instruct-2507-FP8",
"Qwen/Qwen3-4B-Thinking-2507": "Qwen/Qwen3-4B-Thinking-2507",
"Qwen/Qwen3-4B-Thinking-2507-FP8": "Qwen/Qwen3-4B-Thinking-2507-FP8",
"Qwen/Qwen3-8B": "Qwen/Qwen3-8B",
"Qwen/Qwen3-8B-FP8": "Qwen/Qwen3-8B-FP8",
"Qwen/Qwen3-14B": "Qwen/Qwen3-14B",
"Qwen/Qwen3-14B-FP8": "Qwen/Qwen3-14B-FP8",
"Qwen/Qwen3-32B": "Qwen/Qwen3-32B",
"Qwen/Qwen3-32B-FP8": "Qwen/Qwen3-32B-FP8",
"Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B",
"Qwen/Qwen3-30B-A3B-Instruct-2507-FP8": "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8",
"Qwen/Qwen3-30B-A3B-Thinking-2507-FP8": "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8",
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit",
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-4bit",
"Qwen/Qwen3-235B-A22B-GPTQ-Int4": "mlx-community/Qwen3-235B-A22B-4bit",
"Qwen/Qwen2.5-0.5B-Instruct": "Qwen/Qwen2.5-0.5B-Instruct",
"Qwen/Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-3B-Instruct": "Qwen/Qwen2.5-3B-Instruct",
"Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct",
"Qwen/Qwen2.5-14B-Instruct": "Qwen/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-32B-Instruct": "Qwen/Qwen2.5-32B-Instruct",
"Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct",
"nvidia/Llama-3.3-70B-Instruct-FP8": "nvidia/Llama-3.3-70B-Instruct-FP8",
"nvidia/Llama-3.1-70B-Instruct-FP8": "nvidia/Llama-3.1-70B-Instruct-FP8",
"nvidia/Llama-3.1-8B-Instruct-FP8": "nvidia/Llama-3.1-8B-Instruct-FP8",
"deepseek-ai/DeepSeek-V3.1": "deepseek-ai/DeepSeek-V3.1",
"deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1",
"deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3",
"deepseek-ai/DeepSeek-V2": "deepseek-ai/DeepSeek-V2",
"MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit",
"zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit",
}

logger = logging.getLogger(__name__)
NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join"""

NODE_JOIN_COMMAND_PUBLIC_NETWORK = """parallax join -s {scheduler_addr} """
@@ -76,6 +78,8 @@ def _load_config_only(name: str) -> dict:
config = _load_config_only(model_name)

# get quant method
# logger.info(f"Loading model config from {model_name}")

quant_method = config.get("quant_method", None)
quantization_config = config.get("quantization_config", None)
if quant_method is None and quantization_config is not None:
@@ -88,9 +92,14 @@ def _load_config_only(name: str) -> dict:
elif quant_method in ("mxfp4", "int4", "awq", "gptq"):
param_bytes_per_element = 0.5

# Only for hack, fix it when support different quantization bits
# if "minimax-m2" in model_name.lower():
# param_bytes_per_element = 0.5
mlx_param_bytes_per_element = param_bytes_per_element
mlx_model_name = MODELS.get(model_name, model_name)

if mlx_model_name != model_name:
mlx_config = _load_config_only(mlx_model_name)
mlx_quant_dict = mlx_config.get("quantization_config", None)
if mlx_quant_dict and "bits" in mlx_quant_dict:
mlx_param_bytes_per_element = mlx_quant_dict["bits"] / 8

# get local experts
num_local_experts = config.get("num_local_experts", None)
@@ -101,6 +110,7 @@ def _load_config_only(name: str) -> dict:

model_info = ModelInfo(
model_name=model_name,
mlx_model_name=mlx_model_name,
head_size=config.get("head_dim", 128),
qk_nope_head_dim=config.get("qk_nope_head_dim", None),
qk_rope_head_dim=config.get("qk_rope_head_dim", None),
@@ -112,6 +122,7 @@ def _load_config_only(name: str) -> dict:
num_layers=config.get("num_hidden_layers", 0),
ffn_num_projections=3,
param_bytes_per_element=param_bytes_per_element,
mlx_param_bytes_per_element=mlx_param_bytes_per_element,
cache_bytes_per_element=2,
embedding_bytes_per_element=2,
num_local_experts=num_local_experts,
@@ -122,7 +133,7 @@ def _load_config_only(name: str) -> dict:


def get_model_list():
return MODEL_LIST
return list(MODELS.keys())


def get_node_join_command(scheduler_addr, is_local_network):
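
Note: `MODELS` now serves two purposes in this file: it is the supported-model list (`get_model_list()` returns its keys) and the lookup table from a source repo to its MLX conversion, which in turn decides `mlx_param_bytes_per_element`. A hedged sketch of that derivation, with a fake config loader standing in for `_load_config_only` (the bit widths below are illustrative):

```python
def mlx_bytes_per_element(model_name, base_bytes, models, load_config):
    """Bytes per weight element for the MLX variant, falling back to the source width."""
    mlx_name = models.get(model_name, model_name)
    if mlx_name == model_name:
        return base_bytes  # no separate MLX conversion listed
    quant = load_config(mlx_name).get("quantization_config") or {}
    return quant["bits"] / 8 if "bits" in quant else base_bytes

fake_configs = {
    "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit": {"quantization_config": {"bits": 4}},
}
models_stub = {
    "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit",
}
bytes_mlx = mlx_bytes_per_element(
    "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8",
    base_bytes=1.0,  # FP8 source checkpoint
    models=models_stub,
    load_config=lambda name: fake_configs.get(name, {}),
)
print(bytes_mlx)  # 0.5 -> the 4-bit MLX conversion is half the width of the FP8 source
```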
src/parallax/launch.py (26 changes: 0 additions & 26 deletions)
@@ -23,26 +23,11 @@
from parallax.server.executor import Executor
from parallax.server.http_server import launch_http_server
from parallax.server.server_args import parse_args
from parallax.utils.utils import get_current_device
from parallax_utils.ascii_anime import display_parallax_join
from parallax_utils.logging_config import get_logger, set_log_level

logger = get_logger("parallax.launch")

"""Currently hard code model name for MAC"""
MLX_MODEL_NAME_MAP = {
"openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8",
"openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-4bit",
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit",
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-4bit",
"Qwen/Qwen3-235B-A22B-GPTQ-Int4": "mlx-community/Qwen3-235B-A22B-4bit",
"moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit",
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx",
"MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit",
"zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit",
}

if __name__ == "__main__":
multiprocessing.set_start_method("spawn", force=True)
@@ -64,12 +49,6 @@

logger.debug(f"executor_input_addr: {args.executor_input_ipc}")
logger.debug(f"executor_output_addr: {args.executor_output_ipc}")
# Hard code for mlx-community models
if get_current_device() == "mlx":
mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None)
if mlx_model_repo is not None:
args.model_path = mlx_model_repo
logger.debug(f"Replace mlx model path: {mlx_model_repo}")
if args.scheduler_addr is None:
if args.log_level != "DEBUG":
display_parallax_join(args.model_path)
@@ -122,11 +101,6 @@
args.end_layer = gradient_server.block_end_index
args.model_path = gradient_server.model_name
# Hard code for mlx-community models
if get_current_device() == "mlx":
mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None)
if mlx_model_repo is not None:
args.model_path = mlx_model_repo
logger.debug(f"Replace mlx model path: {mlx_model_repo}")
logger.debug(
f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}"
)
src/parallax/server/server_info.py (4 changes: 4 additions & 0 deletions)
@@ -182,6 +182,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
"gpu_name": "Unknown",
"memory_gb": 16.0,
"memory_bandwidth_gbps": 100.0,
"device": "Unknown",
}

if isinstance(hw, NvidiaHardwareInfo):
@@ -191,6 +192,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
"gpu_name": hw.chip,
"memory_gb": hw.vram_gb,
"memory_bandwidth_gbps": hw.memory_bandwidth_gbps,
"device": "cuda",
}
if isinstance(hw, AppleSiliconHardwareInfo):
# Use unified memory size as memory_gb; bandwidth rough estimate per family
@@ -201,6 +203,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
"gpu_name": hw.chip,
"memory_gb": hw.total_ram_gb,
"memory_bandwidth_gbps": est_bandwidth,
"device": "mlx",
}
# Generic fallback
return {
@@ -209,6 +212,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]:
"gpu_name": "Unknown",
"memory_gb": 16.0,
"memory_bandwidth_gbps": 100.0,
"device": "Unknown",
}


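
Note: `detect_node_hardware` now reports a `device` string (`"cuda"`, `"mlx"`, or `"Unknown"`) that `build_hardware` forwards into `NodeHardwareInfo`. A hedged example of the payload shape for an Apple Silicon node (the values are illustrative, not measured):

```python
apple_node = {
    "node_id": "node-123",            # hypothetical id
    "tflops_fp16": 27.2,
    "gpu_name": "Apple M3 Max",
    "memory_gb": 96.0,
    "memory_bandwidth_gbps": 400.0,
    "device": "mlx",                  # new field; "cuda" for NVIDIA, "Unknown" otherwise
}
```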
src/scheduling/model_info.py (6 changes: 6 additions & 0 deletions)
@@ -23,6 +23,7 @@ class ModelInfo:
"""

model_name: str
mlx_model_name: str
head_size: int
hidden_dim: int
intermediate_dim: int
@@ -37,6 +38,7 @@ class ModelInfo:
tie_embedding: bool = False
# Default int8
param_bytes_per_element: float = 1
mlx_param_bytes_per_element: float = 1
cache_bytes_per_element: int = 1
embedding_bytes_per_element: int = 1

@@ -70,6 +72,10 @@ def k_dim(self) -> int:
"""Return key head dim."""
return self.num_kv_heads * self.head_size_k

@property
def mlx_bit_factor(self) -> float:
return self.mlx_param_bytes_per_element / self.param_bytes_per_element

@property
def embedding_io_bytes(self) -> int:
"""Estimate memory for input_embeddings / or lm_head."""
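
Note: `mlx_bit_factor` is just the ratio of the two per-element widths. A worked example with assumed values (FP8 source checkpoint, 4-bit MLX conversion):

```python
param_bytes_per_element = 1.0      # FP8 source checkpoint
mlx_param_bytes_per_element = 0.5  # 4-bit MLX conversion
mlx_bit_factor = mlx_param_bytes_per_element / param_bytes_per_element
print(mlx_bit_factor)              # 0.5 -> MLX weights need half the bytes per layer
```

When no MLX variant exists, both widths match and the factor stays at 1.0, so non-MLX nodes are unaffected.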
src/scheduling/node.py (17 changes: 14 additions & 3 deletions)
@@ -35,6 +35,7 @@ class NodeHardwareInfo:
gpu_name: str
memory_gb: float
memory_bandwidth_gbps: float
device: str


@dataclass
@@ -294,9 +295,19 @@ def get_decoder_layer_capacity(
available_memory_bytes,
self.model_info.decoder_layer_io_bytes(roofline=False),
)
return floor(
available_memory_bytes / self.model_info.decoder_layer_io_bytes(roofline=False)
)
if self.hardware.device == "mlx":
# For mlx, consider mlx bit factor
return floor(
available_memory_bytes
/ (
self.model_info.decoder_layer_io_bytes(roofline=False)
* self.model_info.mlx_bit_factor
)
)
else:
return floor(
available_memory_bytes / self.model_info.decoder_layer_io_bytes(roofline=False)
)

@property
def per_decoder_layer_kv_cache_memory(self) -> Optional[int]:
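
Note: for MLX nodes the capacity estimate now divides by the per-layer weight bytes scaled by `mlx_bit_factor`. Hedged arithmetic with made-up sizes:

```python
from math import floor

available_memory_bytes = 40 * 1024**3        # 40 GiB left for weights (assumed)
decoder_layer_io_bytes = 1.2 * 1024**3       # per-layer bytes at source precision (assumed)
mlx_bit_factor = 0.5                         # 4-bit MLX vs. FP8 source

cuda_layers = floor(available_memory_bytes / decoder_layer_io_bytes)                    # 33
mlx_layers = floor(available_memory_bytes / (decoder_layer_io_bytes * mlx_bit_factor))  # 66
```

So the same amount of free memory hosts roughly twice as many decoder layers once the 4-bit MLX weights are accounted for.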
tests/scheduler_tests/test_layer_allocation.py (8 changes: 4 additions & 4 deletions)
@@ -26,10 +26,10 @@

def _build_node(gpu_type: str, model: ModelInfo, id_suffix: str = "") -> Node:
hw_map = {
"a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 312.0, "", 80.0, 2039.0),
"a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 312.0, "", 40.0, 1935.0),
"rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 165, "", 32.0, 1792.0),
"rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 82.6, "", 24.0, 1008.0),
"a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 312.0, "", 80.0, 2039.0, "cuda"),
"a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 312.0, "", 40.0, 1935.0, "cuda"),
"rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 165, "", 32.0, 1792.0, "cuda"),
"rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 82.6, "", 24.0, 1008.0, "cuda"),
}
hw = hw_map[gpu_type]
return Node(node_id=hw.node_id, hardware=hw, model_info=model)
tests/scheduler_tests/test_scheduler.py (1 change: 1 addition & 0 deletions)
@@ -18,6 +18,7 @@ def _build_node(node_id: str, model: ModelInfo, *, tflops: float, mem_gb: float)
gpu_name="",
memory_gb=mem_gb,
memory_bandwidth_gbps=1000.0,
device="cuda",
)
n = Node(node_id=node_id, hardware=hw, model_info=model)
# Ensure latency estimation uses a defined speedup