diff --git a/src/backend/server/rpc_connection_handler.py b/src/backend/server/rpc_connection_handler.py index 0fee9922..0288f26c 100644 --- a/src/backend/server/rpc_connection_handler.py +++ b/src/backend/server/rpc_connection_handler.py @@ -140,12 +140,18 @@ def get_layer_allocation(self, current_node_id): list_node_allocations = self.scheduler.list_node_allocations() for node_id, start_layer, end_layer in list_node_allocations: if current_node_id == node_id: - return { - "node_id": node_id, - "model_name": self.scheduler.model_info.model_name, - "start_layer": start_layer, - "end_layer": end_layer, - } + node = self.scheduler.node_id_to_node.get(node_id) + if node: + return { + "node_id": node_id, + "model_name": ( + node.model_info.model_name + if node.hardware.device != "mlx" + else node.model_info.mlx_model_name + ), + "start_layer": start_layer, + "end_layer": end_layer, + } return {} def build_node(self, node_json: dict): @@ -177,10 +183,12 @@ def build_hardware(self, hardware_json): gpu_name = hardware_json.get("gpu_name") memory_gb = hardware_json.get("memory_gb") memory_bandwidth_gbps = hardware_json.get("memory_bandwidth_gbps") + device = hardware_json.get("device") return NodeHardwareInfo( node_id=node_id, tflops_fp16=tflops_fp16, gpu_name=gpu_name, memory_gb=memory_gb, memory_bandwidth_gbps=memory_bandwidth_gbps, + device=device, ) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index d2691c39..5974e221 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -1,58 +1,60 @@ import json +import logging from pathlib import Path from scheduling.model_info import ModelInfo -# Supported model list -MODEL_LIST = [ - "Qwen/Qwen3-0.6B", - "openai/gpt-oss-20b", - "openai/gpt-oss-120b", - "moonshotai/Kimi-K2-Instruct", - "moonshotai/Kimi-K2-Instruct-0905", - "Qwen/Qwen3-Next-80B-A3B-Instruct", - "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8", - "Qwen/Qwen3-Next-80B-A3B-Thinking", - "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8", - "Qwen/Qwen3-0.6B-FP8", - "Qwen/Qwen3-1.7B", - "Qwen/Qwen3-1.7B-FP8", - "Qwen/Qwen3-4B", - "Qwen/Qwen3-4B-FP8", - "Qwen/Qwen3-4B-Instruct-2507", - "Qwen/Qwen3-4B-Instruct-2507-FP8", - "Qwen/Qwen3-4B-Thinking-2507", - "Qwen/Qwen3-4B-Thinking-2507-FP8", - "Qwen/Qwen3-8B", - "Qwen/Qwen3-8B-FP8", - "Qwen/Qwen3-14B", - "Qwen/Qwen3-14B-FP8", - "Qwen/Qwen3-32B", - "Qwen/Qwen3-32B-FP8", - "Qwen/Qwen3-30B-A3B", - "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8", - "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8", - "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8", - "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8", - "Qwen/Qwen3-235B-A22B-GPTQ-Int4", - "Qwen/Qwen2.5-0.5B-Instruct", - "Qwen/Qwen2.5-1.5B-Instruct", - "Qwen/Qwen2.5-3B-Instruct", - "Qwen/Qwen2.5-7B-Instruct", - "Qwen/Qwen2.5-14B-Instruct", - "Qwen/Qwen2.5-32B-Instruct", - "Qwen/Qwen2.5-72B-Instruct", - "nvidia/Llama-3.3-70B-Instruct-FP8", - "nvidia/Llama-3.1-70B-Instruct-FP8", - "nvidia/Llama-3.1-8B-Instruct-FP8", - "deepseek-ai/DeepSeek-V3.1", - "deepseek-ai/DeepSeek-R1", - "deepseek-ai/DeepSeek-V3", - "deepseek-ai/DeepSeek-V2", - "MiniMaxAI/MiniMax-M2", - "zai-org/GLM-4.6", -] - +# Supported model list - key: model name, value: MLX model name (same as key if no MLX variant) +MODELS = { + "Qwen/Qwen3-0.6B": "Qwen/Qwen3-0.6B", + "openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8", + "openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-4bit", + "moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit", + "moonshotai/Kimi-K2-Instruct-0905": 
"mlx-community/Kimi-K2-Instruct-0905-mlx-DQ3_K_M", + "Qwen/Qwen3-Next-80B-A3B-Instruct": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit", + "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit", + "Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit", + "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit", + "Qwen/Qwen3-0.6B-FP8": "Qwen/Qwen3-0.6B", + "Qwen/Qwen3-1.7B": "Qwen/Qwen3-1.7B", + "Qwen/Qwen3-1.7B-FP8": "Qwen/Qwen3-1.7B", + "Qwen/Qwen3-4B": "Qwen/Qwen3-4B", + "Qwen/Qwen3-4B-FP8": "Qwen/Qwen3-4B", + "Qwen/Qwen3-4B-Instruct-2507": "Qwen/Qwen3-4B-Instruct-2507", + "Qwen/Qwen3-4B-Instruct-2507-FP8": "Qwen/Qwen3-4B-Instruct-2507-FP8", + "Qwen/Qwen3-4B-Thinking-2507": "Qwen/Qwen3-4B-Thinking-2507", + "Qwen/Qwen3-4B-Thinking-2507-FP8": "Qwen/Qwen3-4B-Thinking-2507-FP8", + "Qwen/Qwen3-8B": "Qwen/Qwen3-8B", + "Qwen/Qwen3-8B-FP8": "Qwen/Qwen3-8B-FP8", + "Qwen/Qwen3-14B": "Qwen/Qwen3-14B", + "Qwen/Qwen3-14B-FP8": "Qwen/Qwen3-14B-FP8", + "Qwen/Qwen3-32B": "Qwen/Qwen3-32B", + "Qwen/Qwen3-32B-FP8": "Qwen/Qwen3-32B-FP8", + "Qwen/Qwen3-30B-A3B": "Qwen/Qwen3-30B-A3B", + "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8": "Qwen/Qwen3-30B-A3B-Instruct-2507-FP8", + "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8": "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8", + "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit", + "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-4bit", + "Qwen/Qwen3-235B-A22B-GPTQ-Int4": "mlx-community/Qwen3-235B-A22B-4bit", + "Qwen/Qwen2.5-0.5B-Instruct": "Qwen/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen2.5-1.5B-Instruct": "Qwen/Qwen2.5-1.5B-Instruct", + "Qwen/Qwen2.5-3B-Instruct": "Qwen/Qwen2.5-3B-Instruct", + "Qwen/Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-14B-Instruct": "Qwen/Qwen2.5-14B-Instruct", + "Qwen/Qwen2.5-32B-Instruct": "Qwen/Qwen2.5-32B-Instruct", + "Qwen/Qwen2.5-72B-Instruct": "Qwen/Qwen2.5-72B-Instruct", + "nvidia/Llama-3.3-70B-Instruct-FP8": "nvidia/Llama-3.3-70B-Instruct-FP8", + "nvidia/Llama-3.1-70B-Instruct-FP8": "nvidia/Llama-3.1-70B-Instruct-FP8", + "nvidia/Llama-3.1-8B-Instruct-FP8": "nvidia/Llama-3.1-8B-Instruct-FP8", + "deepseek-ai/DeepSeek-V3.1": "deepseek-ai/DeepSeek-V3.1", + "deepseek-ai/DeepSeek-R1": "deepseek-ai/DeepSeek-R1", + "deepseek-ai/DeepSeek-V3": "deepseek-ai/DeepSeek-V3", + "deepseek-ai/DeepSeek-V2": "deepseek-ai/DeepSeek-V2", + "MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit", + "zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit", +} + +logger = logging.getLogger(__name__) NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join""" NODE_JOIN_COMMAND_PUBLIC_NETWORK = """parallax join -s {scheduler_addr} """ @@ -76,6 +78,8 @@ def _load_config_only(name: str) -> dict: config = _load_config_only(model_name) # get quant method + # logger.info(f"Loading model config from {model_name}") + quant_method = config.get("quant_method", None) quantization_config = config.get("quantization_config", None) if quant_method is None and quantization_config is not None: @@ -88,9 +92,14 @@ def _load_config_only(name: str) -> dict: elif quant_method in ("mxfp4", "int4", "awq", "gptq"): param_bytes_per_element = 0.5 - # Only for hack, fix it when support different quantization bits - # if "minimax-m2" in model_name.lower(): - # param_bytes_per_element = 0.5 + mlx_param_bytes_per_element = param_bytes_per_element + mlx_model_name = MODELS.get(model_name, model_name) + + if mlx_model_name != model_name: 
+ mlx_config = _load_config_only(mlx_model_name) + mlx_quant_dict = mlx_config.get("quantization_config", None) + if mlx_quant_dict and "bits" in mlx_quant_dict: + mlx_param_bytes_per_element = mlx_quant_dict["bits"] / 8 # get local experts num_local_experts = config.get("num_local_experts", None) @@ -101,6 +110,7 @@ def _load_config_only(name: str) -> dict: model_info = ModelInfo( model_name=model_name, + mlx_model_name=mlx_model_name, head_size=config.get("head_dim", 128), qk_nope_head_dim=config.get("qk_nope_head_dim", None), qk_rope_head_dim=config.get("qk_rope_head_dim", None), @@ -112,6 +122,7 @@ def _load_config_only(name: str) -> dict: num_layers=config.get("num_hidden_layers", 0), ffn_num_projections=3, param_bytes_per_element=param_bytes_per_element, + mlx_param_bytes_per_element=mlx_param_bytes_per_element, cache_bytes_per_element=2, embedding_bytes_per_element=2, num_local_experts=num_local_experts, @@ -122,7 +133,7 @@ def _load_config_only(name: str) -> dict: def get_model_list(): - return MODEL_LIST + return list(MODELS.keys()) def get_node_join_command(scheduler_addr, is_local_network): diff --git a/src/parallax/launch.py b/src/parallax/launch.py index e4443c23..be706465 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -23,26 +23,11 @@ from parallax.server.executor import Executor from parallax.server.http_server import launch_http_server from parallax.server.server_args import parse_args -from parallax.utils.utils import get_current_device from parallax_utils.ascii_anime import display_parallax_join from parallax_utils.logging_config import get_logger, set_log_level logger = get_logger("parallax.launch") -"""Currently hard code model name for MAC""" -MLX_MODEL_NAME_MAP = { - "openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8", - "openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-4bit", - "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit", - "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit", - "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit", - "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-4bit", - "Qwen/Qwen3-235B-A22B-GPTQ-Int4": "mlx-community/Qwen3-235B-A22B-4bit", - "moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit", - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", - "MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit", - "zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit", -} if __name__ == "__main__": multiprocessing.set_start_method("spawn", force=True) @@ -64,12 +49,6 @@ logger.debug(f"executor_input_addr: {args.executor_input_ipc}") logger.debug(f"executor_output_addr: {args.executor_output_ipc}") - # Hard code for mlx-community models - if get_current_device() == "mlx": - mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None) - if mlx_model_repo is not None: - args.model_path = mlx_model_repo - logger.debug(f"Replace mlx model path: {mlx_model_repo}") if args.scheduler_addr is None: if args.log_level != "DEBUG": display_parallax_join(args.model_path) @@ -122,11 +101,6 @@ args.end_layer = gradient_server.block_end_index args.model_path = gradient_server.model_name # Hard code for mlx-community models - if get_current_device() == "mlx": - mlx_model_repo = MLX_MODEL_NAME_MAP.get(args.model_path, None) - if mlx_model_repo is not None: - args.model_path = mlx_model_repo - logger.debug(f"Replace mlx 
model path: {mlx_model_repo}") logger.debug( f"Start Executor with start_layer: {args.start_layer}, end_layer: {args.end_layer}" ) diff --git a/src/parallax/server/server_info.py b/src/parallax/server/server_info.py index 4d056c42..e83d675e 100644 --- a/src/parallax/server/server_info.py +++ b/src/parallax/server/server_info.py @@ -182,6 +182,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]: "gpu_name": "Unknown", "memory_gb": 16.0, "memory_bandwidth_gbps": 100.0, + "device": "Unknown", } if isinstance(hw, NvidiaHardwareInfo): @@ -191,6 +192,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]: "gpu_name": hw.chip, "memory_gb": hw.vram_gb, "memory_bandwidth_gbps": hw.memory_bandwidth_gbps, + "device": "cuda", } if isinstance(hw, AppleSiliconHardwareInfo): # Use unified memory size as memory_gb; bandwidth rough estimate per family @@ -201,6 +203,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]: "gpu_name": hw.chip, "memory_gb": hw.total_ram_gb, "memory_bandwidth_gbps": est_bandwidth, + "device": "mlx", } # Generic fallback return { @@ -209,6 +212,7 @@ def detect_node_hardware(node_id: Optional[str]) -> Dict[str, Any]: "gpu_name": "Unknown", "memory_gb": 16.0, "memory_bandwidth_gbps": 100.0, + "device": "Unknown", } diff --git a/src/scheduling/model_info.py b/src/scheduling/model_info.py index 5d6cd9f2..b79d9e45 100644 --- a/src/scheduling/model_info.py +++ b/src/scheduling/model_info.py @@ -23,6 +23,7 @@ class ModelInfo: """ model_name: str + mlx_model_name: str head_size: int hidden_dim: int intermediate_dim: int @@ -37,6 +38,7 @@ class ModelInfo: tie_embedding: bool = False # Default int8 param_bytes_per_element: float = 1 + mlx_param_bytes_per_element: float = 1 cache_bytes_per_element: int = 1 embedding_bytes_per_element: int = 1 @@ -70,6 +72,10 @@ def k_dim(self) -> int: """Return key head dim.""" return self.num_kv_heads * self.head_size_k + @property + def mlx_bit_factor(self) -> float: + return self.mlx_param_bytes_per_element / self.param_bytes_per_element + @property def embedding_io_bytes(self) -> int: """Estimate memory for input_embeddings / or lm_head.""" diff --git a/src/scheduling/node.py b/src/scheduling/node.py index 5824ddae..d33276c7 100644 --- a/src/scheduling/node.py +++ b/src/scheduling/node.py @@ -35,6 +35,7 @@ class NodeHardwareInfo: gpu_name: str memory_gb: float memory_bandwidth_gbps: float + device: str @dataclass @@ -294,9 +295,19 @@ def get_decoder_layer_capacity( available_memory_bytes, self.model_info.decoder_layer_io_bytes(roofline=False), ) - return floor( - available_memory_bytes / self.model_info.decoder_layer_io_bytes(roofline=False) - ) + if self.hardware.device == "mlx": + # For mlx, consider mlx bit factor + return floor( + available_memory_bytes + / ( + self.model_info.decoder_layer_io_bytes(roofline=False) + * self.model_info.mlx_bit_factor + ) + ) + else: + return floor( + available_memory_bytes / self.model_info.decoder_layer_io_bytes(roofline=False) + ) @property def per_decoder_layer_kv_cache_memory(self) -> Optional[int]: diff --git a/tests/scheduler_tests/test_layer_allocation.py b/tests/scheduler_tests/test_layer_allocation.py index 780a0de8..972950f0 100644 --- a/tests/scheduler_tests/test_layer_allocation.py +++ b/tests/scheduler_tests/test_layer_allocation.py @@ -26,10 +26,10 @@ def _build_node(gpu_type: str, model: ModelInfo, id_suffix: str = "") -> Node: hw_map = { - "a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 312.0, "", 80.0, 2039.0), - "a100-40g": 
NodeHardwareInfo("a100-40g" + id_suffix, 312.0, "", 40.0, 1935.0), - "rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 165, "", 32.0, 1792.0), - "rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 82.6, "", 24.0, 1008.0), + "a100-80g": NodeHardwareInfo("a100-80g" + id_suffix, 312.0, "", 80.0, 2039.0, "cuda"), + "a100-40g": NodeHardwareInfo("a100-40g" + id_suffix, 312.0, "", 40.0, 1935.0, "cuda"), + "rtx5090": NodeHardwareInfo("rtx5090" + id_suffix, 165, "", 32.0, 1792.0, "cuda"), + "rtx4090": NodeHardwareInfo("rtx4090" + id_suffix, 82.6, "", 24.0, 1008.0, "cuda"), } hw = hw_map[gpu_type] return Node(node_id=hw.node_id, hardware=hw, model_info=model) diff --git a/tests/scheduler_tests/test_scheduler.py b/tests/scheduler_tests/test_scheduler.py index 2c8f6083..e8e4ff53 100644 --- a/tests/scheduler_tests/test_scheduler.py +++ b/tests/scheduler_tests/test_scheduler.py @@ -18,6 +18,7 @@ def _build_node(node_id: str, model: ModelInfo, *, tflops: float, mem_gb: float) gpu_name="", memory_gb=mem_gb, memory_bandwidth_gbps=1000.0, + device="cuda", ) n = Node(node_id=node_id, hardware=hw, model_info=model) # Ensure latency estimation uses a defined speedup diff --git a/tests/scheduler_tests/test_utils.py b/tests/scheduler_tests/test_utils.py index d888b88c..47ffcfe1 100644 --- a/tests/scheduler_tests/test_utils.py +++ b/tests/scheduler_tests/test_utils.py @@ -11,16 +11,36 @@ from scheduling.node import Node, NodeHardwareInfo A100_80G = NodeHardwareInfo( - node_id="a100-80g", tflops_fp16=312.0, gpu_name="", memory_gb=80.0, memory_bandwidth_gbps=2039 + node_id="a100-80g", + tflops_fp16=312.0, + gpu_name="", + memory_gb=80.0, + memory_bandwidth_gbps=2039, + device="cuda", ) A100_40G = NodeHardwareInfo( - node_id="a100-40g", tflops_fp16=312.0, gpu_name="", memory_gb=40.0, memory_bandwidth_gbps=1935 + node_id="a100-40g", + tflops_fp16=312.0, + gpu_name="", + memory_gb=40.0, + memory_bandwidth_gbps=1935, + device="cuda", ) RTX5090 = NodeHardwareInfo( - node_id="rtx5090", tflops_fp16=104.8, gpu_name="", memory_gb=32.0, memory_bandwidth_gbps=1792 + node_id="rtx5090", + tflops_fp16=104.8, + gpu_name="", + memory_gb=32.0, + memory_bandwidth_gbps=1792, + device="cuda", ) RTX4090 = NodeHardwareInfo( - node_id="rtx4090", tflops_fp16=82.6, gpu_name="", memory_gb=24.0, memory_bandwidth_gbps=1008 + node_id="rtx4090", + tflops_fp16=82.6, + gpu_name="", + memory_gb=24.0, + memory_bandwidth_gbps=1008, + device="cuda", ) @@ -28,6 +48,7 @@ def build_model_info(num_layers: int) -> ModelInfo: """Build a model config used across tests (matches allocation tests).""" return ModelInfo( model_name=f"GPUOss-{num_layers}L", + mlx_model_name=f"MLXOss-{num_layers}L", head_size=64, hidden_dim=2880, intermediate_dim=2880, @@ -39,6 +60,7 @@ def build_model_info(num_layers: int) -> ModelInfo: num_local_experts=128, num_experts_per_tok=4, param_bytes_per_element=1, + mlx_param_bytes_per_element=1, cache_bytes_per_element=2, embedding_bytes_per_element=2, ) @@ -60,6 +82,7 @@ def build_node( gpu_name="", memory_gb=mem_gb, memory_bandwidth_gbps=mem_bandwidth_gbps, + device="cuda", ) n = Node(node_id=node_id, hardware=hw, model_info=model, _force_max_concurrent_requests=True) # Attach coordinates for RTT synthesis in tests