1 change: 1 addition & 0 deletions src/backend/server/static_config.py
@@ -50,6 +50,7 @@
"deepseek-ai/DeepSeek-V3",
"deepseek-ai/DeepSeek-V2",
"MiniMaxAI/MiniMax-M2",
"zai-org/GLM-4.6",
]

NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join"""
1 change: 1 addition & 0 deletions src/parallax/launch.py
@@ -41,6 +41,7 @@
"moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit",
"deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx",
"MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit",
"zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit",
}

if __name__ == "__main__":
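The table above maps a Hugging Face repo id to a community 4-bit MLX conversion used on Apple-silicon nodes. A minimal sketch of how such a map might be consulted at launch time (the dict and helper names below are illustrative, not the actual identifiers in launch.py):

# Hypothetical helper; assumes a repo-id -> MLX 4-bit repo map like the entries above.
MLX_4BIT_REPOS = {
    "MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit",
    "zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit",
}

def resolve_model_repo(model_name: str, backend: str) -> str:
    """Return the quantized MLX repo when running on the MLX backend, else the original id."""
    if backend == "mlx":
        return MLX_4BIT_REPOS.get(model_name, model_name)
    return model_name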
92 changes: 92 additions & 0 deletions src/parallax/models/glm4_moe.py
@@ -0,0 +1,92 @@
from typing import Optional, Tuple

import mlx.core as mx
from mlx_lm.models.base import scaled_dot_product_attention
from mlx_lm.models.glm4_moe import Attention as MLXGLM4MoeAttention
from mlx_lm.models.glm4_moe import DecoderLayer as MLXGLM4MoeBlock
from mlx_lm.models.glm4_moe import ModelArgs


class ParallaxGLM4MoeAttention(MLXGLM4MoeAttention):
    """GLM4-MoE attention that takes an explicit (past_k, past_v) cache plus a RoPE offset
    and returns the newly computed K/V slices alongside the attention output."""

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
        offset: int = 0,
        lengths: Optional[mx.array] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.n_heads, -1)
        keys = keys.reshape(B, L, self.n_kv_heads, -1)

        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        queries = queries.transpose(0, 2, 1, 3)
        keys = keys.transpose(0, 2, 1, 3)
        values_new = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        # Apply RoPE only to the new tokens; cached keys were already rotated.
        queries_rotated = self.rope(queries, offset=offset)
        keys_rotated = self.rope(keys, offset=offset)

        if cache is not None:
            past_k, past_v = cache
            if past_k is not None and past_v is not None:
                if past_k.shape[2] != offset:
                    raise ValueError(
                        f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} "
                        f"to match RoPE offset {offset} (S_past_padded)."
                    )
                final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2)
                final_values_for_attn = mx.concatenate([past_v, values_new], axis=2)
            else:
                raise ValueError("cache was provided but one of k/v was None.")
        else:
            final_keys_for_attn = keys_rotated
            final_values_for_attn = values_new

        output = scaled_dot_product_attention(
            queries_rotated,
            final_keys_for_attn,
            final_values_for_attn,
            scale=self.scale,
            mask=mask,
            cache=None,
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys_rotated, values_new)


class ParallaxGLM4MoeBlock(MLXGLM4MoeBlock):
    """Decoder layer that swaps the Parallax attention into the upstream mlx-lm GLM4-MoE block,
    leaving KV-cache ownership and the RoPE offset to the caller."""

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__(args, layer_idx)
        self.self_attn = ParallaxGLM4MoeAttention(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
        offset: int = 0,
        lengths: Optional[mx.array] = None,
    ):
        r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, (k_cache, v_cache)

    @classmethod
    def get_architecture(cls):
        """Get the architecture name for the block."""
        return "Glm4MoeForCausalLM"


EntryClass = ParallaxGLM4MoeBlock
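ParallaxGLM4MoeAttention returns only the K/V slices computed for the current tokens, so the caller owns the cache and the RoPE offset. A minimal decode-loop sketch, assuming a constructed ParallaxGLM4MoeBlock named block (illustrative only, not Parallax's actual executor):

import mlx.core as mx

def decode_step(block, x, cache, offset):
    # Run one step; the block returns the output plus the *new* rotated K/V slices.
    out, (k_new, v_new) = block(x, mask=None, cache=cache, offset=offset)
    if cache is None:
        new_cache = (k_new, v_new)
    else:
        past_k, past_v = cache
        # Grow the cache along the sequence axis, mirroring the concat inside the attention.
        new_cache = (
            mx.concatenate([past_k, k_new], axis=2),
            mx.concatenate([past_v, v_new], axis=2),
        )
    # The next call's offset must equal the cached sequence length.
    return out, new_cache, offset + x.shape[1]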
7 changes: 7 additions & 0 deletions src/parallax/sglang/model_runner.py
@@ -496,6 +496,12 @@ def monkey_patch_minimax_m2_model():
    apply_minimax_m2_monkey_patch()


def monkey_patch_glm4_moe_model():
    from parallax.sglang.monkey_patch.glm4_moe_model import apply_glm4_moe_monkey_patch

    apply_glm4_moe_monkey_patch()


def form_sgl_server_args(
    model_path: str,
    dtype: str = "bfloat16",
@@ -529,6 +535,7 @@ def apply_parallax_monkey_patch():
    monkey_patch_gpt_oss()
    monkey_patch_triton_backend_init()
    monkey_patch_minimax_m2_model()
    monkey_patch_glm4_moe_model()


def initialize_sgl_model_runner(
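The GLM4-MoE patch follows the existing registration pattern: a thin wrapper imports and applies the patch module, and apply_parallax_monkey_patch() chains all wrappers. A sketch of the assumed call order (not a verbatim excerpt from model_runner.py; only the two functions shown above are real):

def init_runner_sketch(model_path: str):
    # Patches must run before the sglang model is constructed so that
    # Glm4MoeForCausalLM picks up the patched forward/load_weights.
    apply_parallax_monkey_patch()
    server_args = form_sgl_server_args(model_path)
    # ... build the sglang model runner from server_args after patching ...
    return server_args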
193 changes: 193 additions & 0 deletions src/parallax/sglang/monkey_patch/glm4_moe_model.py
@@ -0,0 +1,193 @@
import logging
from typing import Iterable, Optional, Tuple

import torch
from sglang.srt.distributed import get_pp_group
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader

logger = logging.getLogger(__name__)


def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
    """Load model weights with proper mapping for GLM4 Moe architecture."""
    if is_nextn:
        if hasattr(self.config, "num_nextn_predict_layers"):
            num_nextn_layers = self.config.num_nextn_predict_layers
            assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
            # compatible with old design
            nextn_layer_id = (
                0 if self.config.num_hidden_layers == 1 else self.config.num_hidden_layers
            )
        else:
            raise ValueError("num_nextn_predict_layers is not in the config")

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    expert_params_mapping = FusedMoE.make_expert_params_mapping(
        ckpt_gate_proj_name="gate_proj",
        ckpt_down_proj_name="down_proj",
        ckpt_up_proj_name="up_proj",
        num_experts=self.config.n_routed_experts,
    )

    if is_nextn:
        nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
        nextn_spec_weight_names = [
            "shared_head.norm",
            "eh_proj",
            "enorm",
            "hnorm",
        ]

    params_dict = dict(self.named_parameters())
    weight_names = []
    for name, loaded_weight in weights:
        if "lm_head" in name:
            pp_group = getattr(self, "pp_group", None) or get_pp_group()
            if not pp_group.is_last_rank:
                logger.debug("Skipping lm_head weight '%s' on non-last PP rank", name)
                continue
        layer_id = get_layer_id(name)
        if (
            layer_id is not None
            and hasattr(self.model, "start_layer")
            and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer)
        ):
            continue
        weight_names.append(name)

        if not is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                if num_nextn_layers > 0 and name.startswith("model.layers"):
                    name_list = name.split(".")
                    if len(name_list) >= 3 and int(name_list[2]) >= self.config.num_hidden_layers:
                        continue
        else:
            if not name.startswith(nextn_layer_prefix):
                continue

            # Use shared head and embed weights from target model
            if "shared_head.head" in name or "embed_tokens" in name:
                continue

            is_decoder = True
            # For nextn specific weights
            for weight_name in nextn_spec_weight_names:
                if weight_name in name:
                    name = name.replace(nextn_layer_prefix, "model")
                    is_decoder = False
                    break
            # For decoder layer weights
            if is_decoder:
                name = name.replace(nextn_layer_prefix, "model.decoder")

        if "rotary_emb.inv_freq" in name:
            continue
        for param_name, weight_name, shard_id in stacked_params_mapping:
            # Skip non-stacked layers and experts (experts handled below).
            if weight_name not in name:
                continue
            # We have mlp.experts[0].gate_proj in the checkpoint.
            # Since we handle the experts below in expert_params_mapping,
            # we need to skip here BEFORE we update the name, otherwise
            # name will be updated to mlp.experts[0].gate_up_proj, which
            # will then be updated below in expert_params_mapping
            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
            if "mlp.experts" in name:
                continue
            name = name.replace(weight_name, param_name)
            # Skip loading extra bias for GPTQ models.
            if name.endswith(".bias") and name not in params_dict:
                continue
            if name not in params_dict:
                continue

            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            # Track if this is an expert weight to enable early skipping
            is_expert_weight = False

            for mapping in expert_params_mapping:
                param_name, weight_name, expert_id, shard_id = mapping
                if weight_name not in name:
                    continue

                # Mark as expert weight regardless of whether we can process it
                is_expert_weight = True

                name = name.replace(weight_name, param_name)
                if name not in params_dict:
                    # Expert weight not on this rank, will be skipped below
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(
                    param,
                    loaded_weight,
                    name,
                    shard_id=shard_id,
                    expert_id=expert_id,
                )
                break
            else:
                if is_expert_weight:
                    # This is an expert weight but not mapped to this rank, skip all remaining processing
                    continue

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue

                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")


def apply_glm4_moe_monkey_patch():
    """Apply monkey patches to GLM4 Moe for PP support and weight loading."""
    import sglang.srt.models.glm4_moe as glm4_moe_module

    def pp_forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch,
        inputs_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        **kwargs,
    ):
        hidden_states = self.model(
            input_ids, positions, forward_batch, inputs_embeds, pp_proxy_tensors
        )

        if isinstance(hidden_states, PPProxyTensors):
            return hidden_states

        pp_group = getattr(self, "pp_group", None) or get_pp_group()
        if pp_group.is_last_rank:
            return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)
        else:
            return hidden_states

    glm4_moe_module.Glm4MoeForCausalLM.forward = pp_forward
    glm4_moe_module.Glm4MoeForCausalLM.load_weights = monkey_patch_load_weights
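The stacked_params_mapping in monkey_patch_load_weights folds per-projection checkpoint tensors into the fused qkv_proj and gate_up_proj parameters, while expert tensors are instead routed through FusedMoE.make_expert_params_mapping. A standalone sketch of the name rewrite (no sglang imports; real loading also dispatches to each parameter's weight_loader with the shard id):

STACKED = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def map_checkpoint_name(name: str):
    """Return (target_param_name, shard_id), or (name, None) for non-stacked weights."""
    for param_name, weight_name, shard_id in STACKED:
        # Expert tensors are skipped here, exactly as in the patch above.
        if weight_name in name and "mlp.experts" not in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

# e.g. "model.layers.3.self_attn.k_proj.weight" -> ("model.layers.3.self_attn.qkv_proj.weight", "k")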