14 changes: 11 additions & 3 deletions src/backend/server/static_config.py
@@ -11,8 +11,8 @@
     "openai/gpt-oss-120b",
     "moonshotai/Kimi-K2-Instruct",
     "moonshotai/Kimi-K2-Instruct-0905",
-    "Qwen/Qwen3-Next-80B-A3B-Instruct",
-    "Qwen/Qwen3-Next-80B-A3B-Thinking",
+    "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8",
+    "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8",
     # "Qwen/Qwen3-8B",
     # "Qwen/Qwen3-8B-FP8",
     "Qwen/Qwen3-32B",
@@ -60,6 +60,13 @@ def get_model_info(model_name):
     elif quant_method in ("mxfp4", "int4", "awq", "gptq"):
         param_bytes_per_element = 0.5
 
+    # get local experts
+    num_local_experts = config.get("num_local_experts", None)
+    if num_local_experts is None:
+        num_local_experts = config.get("num_experts", None)
+    if num_local_experts is None:
+        num_local_experts = config.get("n_routed_experts", None)
+
     model_info = ModelInfo(
         model_name=model_name,
         head_size=config.get("head_dim", 128),
@@ -75,8 +82,9 @@ def get_model_info(model_name)
         param_bytes_per_element=param_bytes_per_element,
         cache_bytes_per_element=2,
         embedding_bytes_per_element=2,
-        num_local_experts=config.get("num_experts", None),
+        num_local_experts=num_local_experts,
         num_experts_per_tok=config.get("num_experts_per_tok", None),
+        moe_intermediate_dim=config.get("moe_intermediate_size", None),
     )
     return model_info
 
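Note: the fallback chain added above exists because Hugging Face configs expose the routed-expert count under different keys depending on the model family. A minimal standalone sketch of the same lookup (the helper name and the example configs below are illustrative, not part of this PR):

def _lookup_num_local_experts(config: dict):
    # Try the keys in the same order as the change above; dense models return None.
    for key in ("num_local_experts", "num_experts", "n_routed_experts"):
        if config.get(key) is not None:
            return config[key]
    return None

_lookup_num_local_experts({"num_experts": 512})       # -> 512
_lookup_num_local_experts({"n_routed_experts": 256})  # -> 256
_lookup_num_local_experts({"hidden_size": 4096})      # -> None (dense model)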
5 changes: 5 additions & 0 deletions src/parallax/launch.py
@@ -30,6 +30,11 @@
 MLX_MODEL_NAME_MAP = {
     "openai/gpt-oss-20b": "mlx-community/gpt-oss-20b-MXFP4-Q8",
     "openai/gpt-oss-120b": "mlx-community/gpt-oss-120b-4bit",
+    "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
+    "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
+    "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-4bit",
+    "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-4bit",
+    "moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit",
 }
 
 if __name__ == "__main__":
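The new entries map FP8 repos on Hugging Face to mlx-community quantized equivalents. Presumably this table is consulted at launch time on the MLX backend so the FP8 checkpoint is swapped for a locally runnable quantization; a minimal sketch of such a lookup, assuming the name falls through unchanged when no mapping exists (the helper is hypothetical, not code from this PR):

def resolve_mlx_model_name(model_name: str) -> str:
    # Use the registered mlx-community repo if present, otherwise keep the original name.
    return MLX_MODEL_NAME_MAP.get(model_name, model_name)

resolve_mlx_model_name("Qwen/Qwen3-Next-80B-A3B-Instruct-FP8")
# -> "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit"
resolve_mlx_model_name("Qwen/Qwen3-32B")  # -> "Qwen/Qwen3-32B" (no override registered)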
7 changes: 7 additions & 0 deletions src/scheduling/layer_allocation.py
@@ -822,6 +822,13 @@ def global_allocation(self) -> bool:
                 total_cap,
             )
             return False
+        else:
+            logger.debug(
+                "[DP] Sufficient resources: nodes=%d, layers=%d, total_cap=%d",
+                num_nodes,
+                num_layers,
+                total_cap,
+            )
         # used for pruning
         suffix_sum = [0] * (num_nodes + 1)
         for i in range(num_nodes - 1, -1, -1):
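The added branch only logs that capacity is sufficient; the pruning itself relies on the suffix_sum array started in the context lines, whose loop body the diff truncates. A self-contained sketch of that pruning idea with made-up capacities (the loop body and helper below are assumptions, not lines from this PR):

capacities = [13, 5, 5, 3]  # hypothetical per-node decoder-layer capacities
num_nodes = len(capacities)

# suffix_sum[i] = total capacity of nodes i..num_nodes-1
suffix_sum = [0] * (num_nodes + 1)
for i in range(num_nodes - 1, -1, -1):
    suffix_sum[i] = suffix_sum[i + 1] + capacities[i]

def can_still_place(node_idx: int, layers_left: int) -> bool:
    # Prune a partial allocation once the remaining nodes cannot hold the remaining layers.
    return layers_left <= suffix_sum[node_idx]

can_still_place(0, 14)  # -> True  (26 layers of capacity remain)
can_still_place(1, 14)  # -> False (nodes 1..3 hold at most 13 layers)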
32 changes: 24 additions & 8 deletions src/scheduling/model_info.py
@@ -10,6 +10,10 @@
 from dataclasses import dataclass
 from typing import Optional
 
+from parallax_utils.logging_config import get_logger
+
+logger = get_logger(__name__)
+
 
 @dataclass
 class ModelInfo:
@@ -29,6 +33,7 @@ class ModelInfo:
     ffn_num_projections: int = 3
     num_local_experts: Optional[int] = None
     num_experts_per_tok: Optional[int] = None
+    moe_intermediate_dim: Optional[int] = None
     tie_embedding: bool = False
     # Default int8
     param_bytes_per_element: float = 1
@@ -50,6 +55,11 @@ def __init__(self, **kwargs):
         self.head_size_k = self.head_size
         self.head_size_v = self.head_size
 
+    @property
+    def q_dim(self) -> int:
+        """Return query head dim."""
+        return self.num_attention_heads * self.head_size
+
     @property
     def v_dim(self) -> int:
         """Return key and value head dim."""
@@ -143,17 +153,17 @@ def decoder_layer_io_bytes(
             source_seq_len: Source sequence length (prompt tokens)
         """
         # Attention params
-        qo_params = self.param_bytes_per_element * self.hidden_dim * self.hidden_dim
-        kv_params = self.param_bytes_per_element * self.hidden_dim * (self.k_dim + self.v_dim) // 2
+        qo_params = self.param_bytes_per_element * self.hidden_dim * self.q_dim * 2
+        kv_params = self.param_bytes_per_element * self.hidden_dim * (self.k_dim + self.v_dim)
         attention_params = qo_params + kv_params
 
         # FFN params
-        ffn_params = (
-            self.param_bytes_per_element
-            * self.ffn_num_projections
-            * self.hidden_dim
-            * self.intermediate_dim
-        )
+        ffn_params = self.param_bytes_per_element * self.ffn_num_projections * self.hidden_dim
+        if self.moe_intermediate_dim is not None:
+            ffn_params *= self.moe_intermediate_dim
+        else:
+            ffn_params *= self.intermediate_dim
+
         if roofline:
             expected_experts = self.expected_num_activated_experts(
                 batch_size=batch_size, target_seq_len=target_seq_len
@@ -168,6 +178,12 @@
             ffn_params *= self.num_local_experts
             kv_cache_size = 0
 
+        logger.debug(
+            "Model Info ffn_params=%d, kv_cache_size=%d, attention_params=%d",
+            ffn_params,
+            kv_cache_size,
+            attention_params,
+        )
         return round(ffn_params + kv_cache_size + attention_params)
 
     def lm_head_flops(self, target_seq_len: int = 1) -> int:
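The reworked attention terms size the Q/O projections from q_dim * 2 rather than assuming hidden_dim equals num_heads * head_size, the K/V term drops the old halving, and MoE models now use moe_intermediate_size for the expert FFN. A worked example with made-up dimensions (none of these numbers come from a real config):

param_bytes = 1                                   # int8 weights
hidden_dim = 2048
num_heads, num_kv_heads, head_size = 32, 4, 128

q_dim = num_heads * head_size                     # 4096
kv_dim = num_kv_heads * head_size * 2             # k_dim + v_dim = 1024

qo_params = param_bytes * hidden_dim * q_dim * 2  # W_q plus W_o: ~16.8 MB
kv_params = param_bytes * hidden_dim * kv_dim     # W_k plus W_v: ~2.1 MB

num_local_experts, moe_intermediate_dim = 64, 512
ffn_params = param_bytes * 3 * hidden_dim * moe_intermediate_dim * num_local_experts  # ~201.3 MB

per_layer_weight_bytes = qo_params + kv_params + ffn_params  # ~220 MB per decoder layer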
5 changes: 5 additions & 0 deletions src/scheduling/node.py
@@ -289,6 +289,11 @@ def get_decoder_layer_capacity(
         if not (include_input_embed and self.model_info.tie_embedding):
             available_memory_bytes -= self.model_info.embedding_io_bytes
 
+        logger.debug(
+            "Node available_memory_bytes=%d, decoder_layer_io_bytes=%d",
+            available_memory_bytes,
+            self.model_info.decoder_layer_io_bytes(roofline=False),
+        )
         return floor(
             available_memory_bytes / self.model_info.decoder_layer_io_bytes(roofline=False)
         )
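The added log exposes both operands of the division below, which makes the per-node capacity easy to sanity-check by hand. With illustrative numbers (not measured values):

from math import floor

available_memory_bytes = 24 * 1024**3  # 24 GiB left after KV-cache and embedding budgets
per_layer_bytes = 1_600_000_000        # decoder_layer_io_bytes(roofline=False) for some model

floor(available_memory_bytes / per_layer_bytes)  # -> 16 decoder layers fit on this node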
16 changes: 8 additions & 8 deletions tests/scheduler_tests/test_layer_allocation.py
@@ -56,12 +56,12 @@ def test_capacity_sanity_check():
     "num_layers,gpu_types,expected_layers",
     [
         (21, ["a100-80g", "rtx5090", "rtx4090"], [13, 5, 3]),
-        (15, ["a100-80g", "rtx5090"], [10, 5]),
+        (15, ["a100-80g", "rtx5090"], [11, 4]),
         # (20 * 312 : 20 * 165 : 20 * 82.6) / 559.6 = 11.1 : 5.8 : 2.9 -> 12 : 5 : 3
         (20, ["a100-80g", "rtx5090", "rtx4090"], [12, 5, 3]),
         (25, ["a100-80g", "rtx5090", "rtx4090", "rtx4090"], [13, 5, 4, 3]),
         (29, ["rtx4090", "a100-80g", "rtx5090", "rtx5090", "rtx4090"], [3, 13, 5, 5, 3]),
-        (9, ["rtx5090", "rtx5090"], [5, 4]),
+        (8, ["rtx5090", "rtx5090"], [4, 4]),
         (7, ["a100-40g", "rtx5090"], [5, 2]),
     ],
 )
@@ -155,25 +155,25 @@ def _test_gap_patch_rebalance(allocator: BaseLayerAllocator):
             ],
             "dp",
         ),
-        # 14 Layers, capacity (13, 5, 5, 3, 3) -> greedy assigns (9, 5)
+        # 14 Layers, capacity (13, 5, 5, 3, 3) -> greedy assigns (10, 4)
         (
             14,
             (1, 0, 2, 2),
             [
-                (0, 9),
-                (9, 14),
+                (0, 10),
+                (10, 14),
             ],
             "greedy",
         ),
-        # 7 Layers, capacity (6, 5, 5, 3, 3) -> greedy assigns (5, 2, 5, 2)
+        # 7 Layers, capacity (6, 5, 5, 3, 3) -> greedy assigns (5, 2, 4, 3)
         (
             7,
             (0, 1, 2, 2),
             [
                 (0, 5),
                 (5, 7),
-                (0, 5),
-                (5, 7),
+                (0, 4),
+                (4, 7),
             ],
             "greedy",
         ),