diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py
index 39dd184a..d2691c39 100644
--- a/src/backend/server/static_config.py
+++ b/src/backend/server/static_config.py
@@ -50,6 +50,7 @@
     "deepseek-ai/DeepSeek-V3",
     "deepseek-ai/DeepSeek-V2",
     "MiniMaxAI/MiniMax-M2",
+    "zai-org/GLM-4.6",
 ]
 
 NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join"""
diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index 7627388d..e4443c23 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -41,6 +41,7 @@
     "moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit",
     "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx",
     "MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit",
+    "zai-org/GLM-4.6": "mlx-community/GLM-4.6-4bit",
 }
 
 if __name__ == "__main__":
diff --git a/src/parallax/models/glm4_moe.py b/src/parallax/models/glm4_moe.py
new file mode 100644
index 00000000..f35b2955
--- /dev/null
+++ b/src/parallax/models/glm4_moe.py
@@ -0,0 +1,92 @@
+from typing import Optional, Tuple
+
+import mlx.core as mx
+from mlx_lm.models.base import scaled_dot_product_attention
+from mlx_lm.models.glm4_moe import Attention as MLXGLM4MoeAttention
+from mlx_lm.models.glm4_moe import DecoderLayer as MLXGLM4MoeBlock
+from mlx_lm.models.glm4_moe import ModelArgs
+
+
+class ParallaxGLM4MoeAttention(MLXGLM4MoeAttention):
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        offset: int = 0,
+        lengths: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.n_heads, -1)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1)
+
+        if self.use_qk_norm:
+            queries = self.q_norm(queries)
+            keys = self.k_norm(keys)
+
+        queries = queries.transpose(0, 2, 1, 3)
+        keys = keys.transpose(0, 2, 1, 3)
+        values_new = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+
+        queries_rotated = self.rope(queries, offset=offset)
+        keys_rotated = self.rope(keys, offset=offset)
+
+        if cache is not None:
+            past_k, past_v = cache
+            if past_k is not None and past_v is not None:
+                if past_k.shape[2] != offset:
+                    raise ValueError(
+                        f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} "
+                        f"to match RoPE offset {offset} (S_past_padded)."
+                    )
+                final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2)
+                final_values_for_attn = mx.concatenate([past_v, values_new], axis=2)
+            else:
+                raise ValueError("cache was provided but one of k/v was None.")
+        else:
+            final_keys_for_attn = keys_rotated
+            final_values_for_attn = values_new
+
+        output = scaled_dot_product_attention(
+            queries_rotated,
+            final_keys_for_attn,
+            final_values_for_attn,
+            scale=self.scale,
+            mask=mask,
+            cache=None,
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output), (keys_rotated, values_new)
+
+
+class ParallaxGLM4MoeBlock(MLXGLM4MoeBlock):
+
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__(args, layer_idx)
+        self.self_attn = ParallaxGLM4MoeAttention(args)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        offset: int = 0,
+        lengths: Optional[mx.array] = None,
+    ):
+        r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out, (k_cache, v_cache)
+
+    @classmethod
+    def get_architecture(cls):
+        """Get the architecture name for the block."""
+        return "Glm4MoeForCausalLM"
+
+
+EntryClass = ParallaxGLM4MoeBlock
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index e488951b..0692dd34 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -496,6 +496,12 @@ def monkey_patch_minimax_m2_model():
     apply_minimax_m2_monkey_patch()
 
 
+def monkey_patch_glm4_moe_model():
+    from parallax.sglang.monkey_patch.glm4_moe_model import apply_glm4_moe_monkey_patch
+
+    apply_glm4_moe_monkey_patch()
+
+
 def form_sgl_server_args(
     model_path: str,
     dtype: str = "bfloat16",
@@ -529,6 +535,7 @@ def apply_parallax_monkey_patch():
     monkey_patch_gpt_oss()
     monkey_patch_triton_backend_init()
     monkey_patch_minimax_m2_model()
+    monkey_patch_glm4_moe_model()
 
 
 def initialize_sgl_model_runner(
diff --git a/src/parallax/sglang/monkey_patch/glm4_moe_model.py b/src/parallax/sglang/monkey_patch/glm4_moe_model.py
new file mode 100644
index 00000000..233e000d
--- /dev/null
+++ b/src/parallax/sglang/monkey_patch/glm4_moe_model.py
@@ -0,0 +1,193 @@
+import logging
+from typing import Iterable, Optional, Tuple
+
+import torch
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+
+logger = logging.getLogger(__name__)
+
+
+def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+    """Load model weights with proper mapping for GLM4 Moe architecture."""
+    if is_nextn:
+        if hasattr(self.config, "num_nextn_predict_layers"):
+            num_nextn_layers = self.config.num_nextn_predict_layers
+            assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
+            # compatible with old design
+            nextn_layer_id = (
+                0 if self.config.num_hidden_layers == 1 else self.config.num_hidden_layers
+            )
+        else:
+            raise ValueError("num_nextn_predict_layers is not in the config")
+
+    stacked_params_mapping = [
+        # (param_name, shard_name, shard_id)
+        ("qkv_proj", "q_proj", "q"),
+        ("qkv_proj", "k_proj", "k"),
+        ("qkv_proj", "v_proj", "v"),
+        ("gate_up_proj", "gate_proj", 0),
+        ("gate_up_proj", "up_proj", 1),
+    ]
+
+    expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        ckpt_gate_proj_name="gate_proj",
+        ckpt_down_proj_name="down_proj",
+        ckpt_up_proj_name="up_proj",
+        num_experts=self.config.n_routed_experts,
+    )
+
+    if is_nextn:
+        nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
+        nextn_spec_weight_names = [
+            "shared_head.norm",
+            "eh_proj",
+            "enorm",
+            "hnorm",
+        ]
+
+    params_dict = dict(self.named_parameters())
+    weight_names = []
+    for name, loaded_weight in weights:
+        if "lm_head" in name:
+            pp_group = getattr(self, "pp_group", None) or get_pp_group()
+            if not pp_group.is_last_rank:
+                logger.debug("Skipping lm_head weight '%s' on non-last PP rank", name)
+                continue
+        layer_id = get_layer_id(name)
+        if (
+            layer_id is not None
+            and hasattr(self.model, "start_layer")
+            and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer)
+        ):
+            continue
+        weight_names.append(name)
+
+        if not is_nextn:
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if len(name_list) >= 3 and int(name_list[2]) >= self.config.num_hidden_layers:
+                        continue
+        else:
+            if not name.startswith(nextn_layer_prefix):
+                continue
+
+            # Use shared head and embed weights from target model
+            if "shared_head.head" in name or "embed_tokens" in name:
+                continue
+
+            is_decoder = True
+            # For nextn specific weights
+            for weight_name in nextn_spec_weight_names:
+                if weight_name in name:
+                    name = name.replace(nextn_layer_prefix, "model")
+                    is_decoder = False
+                    break
+            # For decoder layer weights
+            if is_decoder:
+                name = name.replace(nextn_layer_prefix, "model.decoder")
+
+        if "rotary_emb.inv_freq" in name:
+            continue
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            # Skip non-stacked layers and experts (experts handled below).
+            if weight_name not in name:
+                continue
+            # We have mlp.experts[0].gate_proj in the checkpoint.
+            # Since we handle the experts below in expert_params_mapping,
+            # we need to skip here BEFORE we update the name, otherwise
+            # name will be updated to mlp.experts[0].gate_up_proj, which
+            # will then be updated below in expert_params_mapping
+            # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+            if "mlp.experts" in name:
+                continue
+            name = name.replace(weight_name, param_name)
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
+            if name not in params_dict:
+                continue
+
+            param = params_dict[name]
+            weight_loader = param.weight_loader
+            weight_loader(param, loaded_weight, shard_id)
+            break
+        else:
+            # Track if this is an expert weight to enable early skipping
+            is_expert_weight = False
+
+            for mapping in expert_params_mapping:
+                param_name, weight_name, expert_id, shard_id = mapping
+                if weight_name not in name:
+                    continue
+
+                # Mark as expert weight regardless of whether we can process it
+                is_expert_weight = True
+
+                name = name.replace(weight_name, param_name)
+                if name not in params_dict:
+                    # Expert weight not on this rank, will be skipped below
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(
+                    param,
+                    loaded_weight,
+                    name,
+                    shard_id=shard_id,
+                    expert_id=expert_id,
+                )
+                break
+            else:
+                if is_expert_weight:
+                    # This is an expert weight but not mapped to this rank, skip all remaining processing
+                    continue
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+def apply_glm4_moe_monkey_patch():
+    """Apply monkey patches to GLM4 Moe for PP support and weight loading."""
+    import sglang.srt.models.glm4_moe as glm4_moe_module
+
+    def pp_forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+        **kwargs,
+    ):
+        hidden_states = self.model(
+            input_ids, positions, forward_batch, inputs_embeds, pp_proxy_tensors
+        )
+
+        if isinstance(hidden_states, PPProxyTensors):
+            return hidden_states
+
+        pp_group = getattr(self, "pp_group", None) or get_pp_group()
+        if pp_group.is_last_rank:
+            return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)
+        else:
+            return hidden_states
+
+    glm4_moe_module.Glm4MoeForCausalLM.forward = pp_forward
+    glm4_moe_module.Glm4MoeForCausalLM.load_weights = monkey_patch_load_weights
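Reviewer note (illustration only, not part of the diff): ParallaxGLM4MoeAttention returns the freshly rotated (keys_rotated, values_new) instead of writing into a cache object, and it requires past_k.shape[2] to equal the RoPE offset. The sketch below shows the caller-side contract this implies; step_with_cache is a hypothetical helper, not something this PR adds, and the executor wiring around it is assumed.

import mlx.core as mx


def step_with_cache(block, x, cache, offset, mask=None):
    """Hypothetical driver for one decode step through a block that follows the
    ParallaxGLM4MoeBlock contract: it returns (hidden, (k_new, v_new)), where
    k_new/v_new cover only the new tokens and are RoPE-rotated at `offset`."""
    out, (k_new, v_new) = block(x, mask=mask, cache=cache, offset=offset)
    if cache is None or cache[0] is None:
        new_cache = (k_new, v_new)
    else:
        # K/V are shaped (B, n_kv_heads, S, head_dim); axis 2 is the sequence
        # axis, the same one the attention concatenates past and new states on.
        new_cache = (
            mx.concatenate([cache[0], k_new], axis=2),
            mx.concatenate([cache[1], v_new], axis=2),
        )
    # The next call must pass offset == new_cache[0].shape[2]; otherwise the
    # attention raises the "Expected past_k sequence length" ValueError.
    return out, new_cache, offset + x.shape[1]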
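Reviewer note (illustration only, not part of the diff): on the SGLang side the patch is reached through apply_parallax_monkey_patch() and monkey_patch_glm4_moe_model(). A minimal sanity check of that rebinding, assuming parallax and sglang import cleanly in the target environment:

import sglang.srt.models.glm4_moe as glm4_moe_module

from parallax.sglang.monkey_patch.glm4_moe_model import (
    apply_glm4_moe_monkey_patch,
    monkey_patch_load_weights,
)

original_forward = glm4_moe_module.Glm4MoeForCausalLM.forward
apply_glm4_moe_monkey_patch()

# forward should now be the PP-aware wrapper that returns hidden states (or
# PPProxyTensors) on non-last ranks, and load_weights should be the Parallax
# loader that skips layers outside [model.start_layer, model.end_layer).
assert glm4_moe_module.Glm4MoeForCausalLM.forward is not original_forward
assert glm4_moe_module.Glm4MoeForCausalLM.load_weights is monkey_patch_load_weights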