From 6dd4499100fa8a81a8f30b1ab0c8507fab2d629a Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Mon, 27 Oct 2025 11:35:41 +0000 Subject: [PATCH 01/25] add gpu support --- src/parallax/sglang/model_runner.py | 9 + .../sglang/monkey_patch/minimax_m2_model.py | 218 ++++++++++++++++++ 2 files changed, 227 insertions(+) create mode 100644 src/parallax/sglang/monkey_patch/minimax_m2_model.py diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 37c075aa..e488951b 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -488,6 +488,14 @@ def monkey_patch_triton_backend_init(): apply_triton_backend_init_monkey_patch() +def monkey_patch_minimax_m2_model(): + from parallax.sglang.monkey_patch.minimax_m2_model import ( + apply_minimax_m2_monkey_patch, + ) + + apply_minimax_m2_monkey_patch() + + def form_sgl_server_args( model_path: str, dtype: str = "bfloat16", @@ -520,6 +528,7 @@ def apply_parallax_monkey_patch(): monkey_patch_qwen3_next() monkey_patch_gpt_oss() monkey_patch_triton_backend_init() + monkey_patch_minimax_m2_model() def initialize_sgl_model_runner( diff --git a/src/parallax/sglang/monkey_patch/minimax_m2_model.py b/src/parallax/sglang/monkey_patch/minimax_m2_model.py new file mode 100644 index 00000000..9c482576 --- /dev/null +++ b/src/parallax/sglang/monkey_patch/minimax_m2_model.py @@ -0,0 +1,218 @@ +from sglang.srt.layers.utils import get_layer_id +import logging +from typing import Iterable, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import ( + get_moe_expert_parallel_world_size, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.communicator import ( + LayerCommunicator, + LayerScatterModes, + ScatterMode, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import get_moe_a2a_backend +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.server_args import get_global_server_args +from sglang.srt.two_batch_overlap import model_forward_maybe_tbo +from sglang.srt.utils import ( + BumpAllocator, + add_prefix, + get_compiler_backend, + is_non_idle_and_non_empty, + make_layers, +) +from sglang.srt.models.minimax_m2 import MiniMaxM2ForCausalLM, get_spec_layer_idx_from_weight_name + + +def 
monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load model weights with proper mapping for MiniMax architecture.""" + + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts, + ) + + params_dict = dict(self.named_parameters()) + logger = logging.getLogger(__name__) + + weight_name_map = { + "lm_head.weight": "model.embed_tokens.weight", + } + + def resolve_param(name: str): + """Resolve weight name to actual parameter, handling tied weights and PP filtering.""" + if name in weight_name_map: + mapped_name = weight_name_map[name] + if mapped_name in params_dict: + logger.debug("Mapped '%s' -> '%s' (tied weight)", name, mapped_name) + return mapped_name, params_dict[mapped_name] + + if name in params_dict: + return name, params_dict[name] + + alt = f"model.{name}" + if alt in params_dict: + return alt, params_dict[alt] + + matches = [k for k in params_dict.keys() if k.endswith(name)] + if len(matches) == 1: + return matches[0], params_dict[matches[0]] + + if name in ("model.norm.weight", "model.embed_tokens.weight"): + logger.debug("Weight '%s' not found (PP-sliced)", name) + return None, None + + if ("lm_head" in name) or ("embed" in name): + sample = [k for k in params_dict.keys() if ("lm_head" in k) or ("embed" in k)] + if not sample: + sample = list(params_dict.keys())[:50] + logger.warning("Failed to resolve '%s'. Sample params: %s", name, sample) + return None, None + + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) + ): + continue + + if "rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if ("mlp.experts." 
in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + if name.endswith(".bias") and name not in params_dict: + continue + + resolved_name, param = resolve_param(name) + if param is None: + if name not in ("model.norm.weight", "model.embed_tokens.weight"): + logger.warning("Skipping weight '%s' (no matching parameter)", name) + continue + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + resolved_name, param = resolve_param(name) + if param is None: + if name not in ("model.norm.weight", "model.embed_tokens.weight"): + logger.warning("Skipping expert weight '%s' (no matching parameter)", name) + continue + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id) + break + else: + if name.endswith(".bias") and name not in params_dict: + continue + + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + resolved_name, param = resolve_param(name) + if param is None: + if name not in ("model.norm.weight", "model.embed_tokens.weight"): + logger.warning("Skipping weight '%s' (no matching parameter)", name) + continue + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def apply_minimax_m2_monkey_patch(): + """Apply monkey patches to MiniMax M2 for PP support and weight loading.""" + import sglang.srt.models.minimax_m2 as m2_module + + orig_init = m2_module.MiniMaxM2ForCausalLM.__init__ + + def pp_init(self, config, quant_config=None, prefix=""): + orig_init(self, config, quant_config, prefix) + self.pp_group = get_pp_group() + + def pp_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch, + inputs_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, forward_batch, inputs_embeds, pp_proxy_tensors + ) + + if isinstance(hidden_states, PPProxyTensors): + return hidden_states + + pp_group = getattr(self, "pp_group", None) or get_pp_group() + if pp_group.is_last_rank: + return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) + else: + return hidden_states + + m2_module.MiniMaxM2ForCausalLM.__init__ = pp_init + m2_module.MiniMaxM2ForCausalLM.forward = pp_forward + m2_module.MiniMaxM2ForCausalLM.load_weights = monkey_patch_load_weights From 0f6d245a7a914c8909d5a1bb2d442c22fda8e0e8 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 27 Oct 2025 20:30:41 +0800 Subject: [PATCH 02/25] update mlx support --- src/parallax/models/minimax_m2.py | 268 ++++++++++++++++++ .../sglang/monkey_patch/minimax_m2_model.py | 55 +--- 2 files changed, 273 insertions(+), 50 deletions(-) create mode 100644 src/parallax/models/minimax_m2.py diff --git a/src/parallax/models/minimax_m2.py b/src/parallax/models/minimax_m2.py new file mode 100644 index 00000000..2550fc63 --- /dev/null +++ b/src/parallax/models/minimax_m2.py @@ -0,0 +1,268 @@ +# Copyright © 2025 Apple Inc. 
+ +from dataclasses import dataclass +from typing import Any, List, Optional +from typing import Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +from mlx_lm.models.base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention +from mlx_lm.models.cache import KVCache +from mlx_lm.models.switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + max_position_embeddings: int + num_experts_per_tok: int + num_local_experts: int + shared_intermediate_size: int + num_hidden_layers: int + rms_norm_eps: float + rope_theta: float + rotary_dim: int + vocab_size: int + block_size: int = 256 + tie_word_embeddings: bool = False + shared_moe_mode: str = "sigmoid" + full_attn_alpha_factor: float = 3.5565588200778455 + full_attn_beta_factor: float = 1.0 + linear_attn_alpha_factor: float = 3.5565588200778455 + linear_attn_beta_factor: float = 1.0 + mlp_alpha_factor: float = 3.5565588200778455 + mlp_beta_factor: float = 1.0 + layer_types: List[str] = None + head_dim: Optional[int] = None + use_qk_norm: bool = True + + +class MLXMiniMaxM2Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.hidden_dim = hidden_size = args.hidden_size + + self.num_attention_heads = args.num_attention_heads + self.num_key_value_heads = args.num_key_value_heads + self.head_dim = head_dim = ( + hidden_size // args.num_attention_heads if args.head_dim is None else args.head_dim + ) + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear( + args.hidden_size, self.num_attention_heads * self.head_dim, bias=False + ) + self.k_proj = nn.Linear( + args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.v_proj = nn.Linear( + args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False + ) + self.o_proj = nn.Linear( + self.num_attention_heads * self.head_dim, args.hidden_size, bias=False + ) + + self.use_qk_norm = args.use_qk_norm + if self.use_qk_norm: + self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps) + + self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + if self.use_qk_norm: + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output) + + +class MiniMaxSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_experts_per_tok = args.num_experts_per_tok + + self.gate = nn.Linear(args.hidden_size, 
args.num_local_experts, bias=False) + self.switch_mlp = SwitchGLU( + args.hidden_size, args.intermediate_size, args.num_local_experts + ) + self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) + + def __call__(self, x: mx.array) -> mx.array: + gates = self.gate(x.astype(mx.float32)) + + scores = mx.sigmoid(gates) + orig_scores = scores + scores = scores + self.e_score_correction_bias + + k = self.num_experts_per_tok + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(orig_scores, inds, axis=-1) + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + return y + + +class MLXMiniMaxM2Block(nn.Module): + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__() + self.mlp_alpha_factor = args.mlp_alpha_factor + self.mlp_beta_factor = args.mlp_beta_factor + + self.self_attn = MLXMiniMaxM2Attention(args) + self.attn_alpha_factor = args.full_attn_alpha_factor + self.attn_beta_factor = args.full_attn_beta_factor + + self.block_sparse_moe = MiniMaxSparseMoeBlock(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = ( + self.input_layernorm(x) * self.attn_alpha_factor + + self.self_attn(x, mask, cache) * self.attn_beta_factor + ) + r = ( + self.block_sparse_moe(self.post_attention_layernorm(x)) * self.mlp_alpha_factor + + r * self.mlp_beta_factor + ) + return r + + +class ParallaxMiniMaxAttention(MLXMiniMaxM2Attention): + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + """ + Attention forward pass with explicit KV cache handling. + + Args: + x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. + mask: (batch, n_q_heads, target_len, source_len) + cache: Optional tuple (past_k, past_v). + shape: (batch, n_kv_heads, S_past_padded, head_dim) + offset: source_len_padded (scalar, used for RoPE calculation). + + Returns: + output_h: (batch, target_len, hidden_dim) - Output hidden states. + new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. + new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. + """ + batch, target_len, _ = x.shape + + queries = self.q_proj(x) + keys = self.k_proj(x) + values = self.v_proj(x) + + queries = queries.reshape(batch, target_len, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # for batch, rope offset is not correct due to padding in batch + queries_rotated = self.rope(queries, offset=offset) + keys_rotated = self.rope(keys, offset=offset) + + if cache is not None: + past_k, past_v = cache + if past_k is not None and past_v is not None: + if past_k.shape[2] != offset: + raise ValueError( + f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " + f"to match RoPE offset {offset} (S_past_padded)." 
+ ) + final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2) + final_values_for_attn = mx.concatenate([past_v, values], axis=2) + else: + raise ValueError("cache was provided but one of k/v was None.") + else: + final_keys_for_attn = keys_rotated + final_values_for_attn = values + + output = scaled_dot_product_attention( + queries_rotated, + final_keys_for_attn, + final_values_for_attn, + scale=self.scale, + mask=mask, + cache=None, + ) + + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + return self.o_proj(output), (keys_rotated, values) + + +class ParallaxMiniMaxM2Block(MLXMiniMaxM2Block): + """A custom transformer block for Parallax, extending the MiniMaxM2 Block class. + This version handles the KV cache explicitly and returns new K and V states. + """ + + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__(args) + self.self_attn = ParallaxMiniMaxAttention(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ): + r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, (k_cache, v_cache) + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "Qwen2ForCausalLM" + + +EntryClass = ParallaxMiniMaxM2Block diff --git a/src/parallax/sglang/monkey_patch/minimax_m2_model.py b/src/parallax/sglang/monkey_patch/minimax_m2_model.py index 9c482576..6a42ec5b 100644 --- a/src/parallax/sglang/monkey_patch/minimax_m2_model.py +++ b/src/parallax/sglang/monkey_patch/minimax_m2_model.py @@ -1,61 +1,16 @@ -from sglang.srt.layers.utils import get_layer_id import logging -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple import torch -from torch import nn -from transformers import PretrainedConfig - -from sglang.srt.distributed import ( - get_moe_expert_parallel_world_size, - get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder -from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.communicator import ( - LayerCommunicator, - LayerScatterModes, - ScatterMode, -) -from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import ( - MergedColumnParallelLinear, - QKVParallelLinear, - ReplicatedLinear, - RowParallelLinear, -) -from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.distributed import get_pp_group from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import TopK -from sglang.srt.layers.moe.utils import get_moe_a2a_backend -from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.layers.rotary_embedding import get_rope -from sglang.srt.layers.utils import PPMissingLayer -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from 
sglang.srt.layers.utils import get_layer_id +from sglang.srt.model_executor.forward_batch_info import PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) -from sglang.srt.server_args import get_global_server_args -from sglang.srt.two_batch_overlap import model_forward_maybe_tbo -from sglang.srt.utils import ( - BumpAllocator, - add_prefix, - get_compiler_backend, - is_non_idle_and_non_empty, - make_layers, -) -from sglang.srt.models.minimax_m2 import MiniMaxM2ForCausalLM, get_spec_layer_idx_from_weight_name +from sglang.srt.models.minimax_m2 import get_spec_layer_idx_from_weight_name def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): From 53b5aad60bcf2a42970830b626618b4e8aefbe3f Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 27 Oct 2025 20:33:45 +0800 Subject: [PATCH 03/25] update --- src/backend/server/static_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 1d4f3a8e..7cef11c9 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -49,6 +49,7 @@ "deepseek-ai/DeepSeek-R1", "deepseek-ai/DeepSeek-V3", "deepseek-ai/DeepSeek-V2", + "MiniMaxAI/MiniMax-M2", ] NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join""" From cfeab39d5405ef4dda81926efe83308677611a4b Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 27 Oct 2025 20:39:43 +0800 Subject: [PATCH 04/25] update --- src/parallax/models/minimax_m2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/models/minimax_m2.py b/src/parallax/models/minimax_m2.py index 2550fc63..fc3123a5 100644 --- a/src/parallax/models/minimax_m2.py +++ b/src/parallax/models/minimax_m2.py @@ -262,7 +262,7 @@ def __call__( @classmethod def get_architecture(cls): """Get the architecture name for the block.""" - return "Qwen2ForCausalLM" + return "MiniMaxM2CausalLM" EntryClass = ParallaxMiniMaxM2Block From c59dadb2eca1383e90a6631853de50b0acdf31a5 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 27 Oct 2025 20:41:03 +0800 Subject: [PATCH 05/25] update --- src/parallax/models/minimax_m2.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/parallax/models/minimax_m2.py b/src/parallax/models/minimax_m2.py index fc3123a5..9aa7268d 100644 --- a/src/parallax/models/minimax_m2.py +++ b/src/parallax/models/minimax_m2.py @@ -1,14 +1,11 @@ # Copyright © 2025 Apple Inc. 
from dataclasses import dataclass -from typing import Any, List, Optional -from typing import Optional, Tuple +from typing import Any, List, Optional, Tuple import mlx.core as mx import mlx.nn as nn - -from mlx_lm.models.base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention -from mlx_lm.models.cache import KVCache +from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention from mlx_lm.models.switch_layers import SwitchGLU From c7a2285bb19d5594acd4258e162faddb3b0c286f Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Mon, 27 Oct 2025 21:12:07 +0800 Subject: [PATCH 06/25] update model_type Map --- src/parallax/server/shard_loader.py | 30 ++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index f15250d6..f7cde070 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -18,6 +18,11 @@ logger = get_logger(__name__) +MODEL_TYPE_MAP = { + "kimi_k2": "deepseek_v3", + "minimax": "minimax_m2", +} + class MLXModelLoader: """ @@ -79,6 +84,22 @@ def register_block_class(self): except Exception as e: logger.warning(f"Failed to load model from {model_file}: {e}") + def _create_default_model_args(self, config: Dict[str, Any]) -> Any: + """Create default model arguments from config.""" + model_args = { + "hidden_size": config.get("hidden_size", 0), + "num_attention_heads": config.get("num_attention_heads", 0), + "num_key_value_heads": config.get("num_key_value_heads", 0), + "num_hidden_layers": config.get("num_hidden_layers", 0), + "intermediate_size": config.get("intermediate_size", 0), + "vocab_size": config.get("vocab_size", 0), + "head_dim": config.get("head_dim", 128), + "num_local_experts": config.get("num_local_experts", None), + "num_experts_per_tok": config.get("num_experts_per_tok", None), + "moe_intermediate_size": config.get("moe_intermediate_size", None), + } + return type("ModelArgs", (), model_args)() + def load( self, lazy: bool = False, strict: bool = True ) -> Tuple[nn.Module, Dict[str, Any], Any]: @@ -115,14 +136,17 @@ def load( # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. 
model_type = config.get("model_type") - if model_type == "kimi_k2": - model_type = "deepseek_v3" + if model_type in MODEL_TYPE_MAP: + model_type = MODEL_TYPE_MAP[model_type] + if not model_type: raise ValueError("model_type not found in config.json") try: - arch_module = importlib.import_module(f"mlx_lm.models.{model_type}") + # Import from project's models directory + arch_module = importlib.import_module(f"parallax.models.{model_type}") model_args_class = getattr(arch_module, "ModelArgs") model_args = model_args_class.from_dict(config) + except (ImportError, AttributeError) as e: raise ValueError(f"Failed to load architecture for model_type '{model_type}'.") from e From 6e63edf0e147ae0d39d81b74bac57864b15032a6 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 10:05:05 +0800 Subject: [PATCH 07/25] update model --- pyproject.toml | 4 +- src/parallax/models/minimax.py | 111 ++++++++++++ src/parallax/models/minimax_m2.py | 265 ---------------------------- src/parallax/server/shard_loader.py | 5 +- 4 files changed, 114 insertions(+), 271 deletions(-) create mode 100644 src/parallax/models/minimax.py delete mode 100644 src/parallax/models/minimax_m2.py diff --git a/pyproject.toml b/pyproject.toml index 1837274f..70e42c58 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,8 +41,8 @@ parallax = "parallax.cli:main" mac = [ "torch==2.8.0", - "mlx-lm==0.28.0", - "mlx==0.29.1", + "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git", + "mlx==0.29.3", ] gpu = [ diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py new file mode 100644 index 00000000..56ff913b --- /dev/null +++ b/src/parallax/models/minimax.py @@ -0,0 +1,111 @@ +# Copyright © 2025 Apple Inc. + +from dataclasses import dataclass +from typing import Any, List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.base import scaled_dot_product_attention +from mlx_lm.models.minimax import MiniMaxAttention as MLXMiniMaxAttention +from mlx_lm.models.minimax import ModelArgs +from mlx_lm.models.minimax import MiniMaxDecoderLayer as MLXMiniMaxBlock + + +class ParallaxMiniMaxAttention(MLXMiniMaxAttention): + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + """ + Attention forward pass with explicit KV cache handling. + + Args: + x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. + mask: (batch, n_q_heads, target_len, source_len) + cache: Optional tuple (past_k, past_v). + shape: (batch, n_kv_heads, S_past_padded, head_dim) + offset: source_len_padded (scalar, used for RoPE calculation). + + Returns: + output_h: (batch, target_len, hidden_dim) - Output hidden states. + new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. + new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. 
+ """ + batch, target_len, _ = x.shape + + queries = self.q_proj(x) + keys = self.k_proj(x) + values = self.v_proj(x) + + queries = queries.reshape(batch, target_len, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + # for batch, rope offset is not correct due to padding in batch + queries_rotated = self.rope(queries, offset=offset) + keys_rotated = self.rope(keys, offset=offset) + + if cache is not None: + past_k, past_v = cache + if past_k is not None and past_v is not None: + if past_k.shape[2] != offset: + raise ValueError( + f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " + f"to match RoPE offset {offset} (S_past_padded)." + ) + final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2) + final_values_for_attn = mx.concatenate([past_v, values], axis=2) + else: + raise ValueError("cache was provided but one of k/v was None.") + else: + final_keys_for_attn = keys_rotated + final_values_for_attn = values + + output = scaled_dot_product_attention( + queries_rotated, + final_keys_for_attn, + final_values_for_attn, + scale=self.scale, + mask=mask, + cache=None, + ) + + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + return self.o_proj(output), (keys_rotated, values) + + +class ParallaxMiniMaxBlock(MLXMiniMaxBlock): + """A custom transformer block for Parallax, extending the MiniMax Block class. + This version handles the KV cache explicitly and returns new K and V states. + """ + + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__(args) + self.self_attn = ParallaxMiniMaxAttention(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ): + r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, (k_cache, v_cache) + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "MiniMaxM2CausalLM" + + +EntryClass = ParallaxMiniMaxBlock diff --git a/src/parallax/models/minimax_m2.py b/src/parallax/models/minimax_m2.py deleted file mode 100644 index 9aa7268d..00000000 --- a/src/parallax/models/minimax_m2.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright © 2025 Apple Inc. 
- -from dataclasses import dataclass -from typing import Any, List, Optional, Tuple - -import mlx.core as mx -import mlx.nn as nn -from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention -from mlx_lm.models.switch_layers import SwitchGLU - - -@dataclass -class ModelArgs(BaseModelArgs): - model_type: str - hidden_size: int - intermediate_size: int - num_attention_heads: int - num_key_value_heads: int - max_position_embeddings: int - num_experts_per_tok: int - num_local_experts: int - shared_intermediate_size: int - num_hidden_layers: int - rms_norm_eps: float - rope_theta: float - rotary_dim: int - vocab_size: int - block_size: int = 256 - tie_word_embeddings: bool = False - shared_moe_mode: str = "sigmoid" - full_attn_alpha_factor: float = 3.5565588200778455 - full_attn_beta_factor: float = 1.0 - linear_attn_alpha_factor: float = 3.5565588200778455 - linear_attn_beta_factor: float = 1.0 - mlp_alpha_factor: float = 3.5565588200778455 - mlp_beta_factor: float = 1.0 - layer_types: List[str] = None - head_dim: Optional[int] = None - use_qk_norm: bool = True - - -class MLXMiniMaxM2Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - self.hidden_dim = hidden_size = args.hidden_size - - self.num_attention_heads = args.num_attention_heads - self.num_key_value_heads = args.num_key_value_heads - self.head_dim = head_dim = ( - hidden_size // args.num_attention_heads if args.head_dim is None else args.head_dim - ) - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear( - args.hidden_size, self.num_attention_heads * self.head_dim, bias=False - ) - self.k_proj = nn.Linear( - args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.v_proj = nn.Linear( - args.hidden_size, self.num_key_value_heads * self.head_dim, bias=False - ) - self.o_proj = nn.Linear( - self.num_attention_heads * self.head_dim, args.hidden_size, bias=False - ) - - self.use_qk_norm = args.use_qk_norm - if self.use_qk_norm: - self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps) - self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps) - - self.rope = nn.RoPE(head_dim, traditional=False, base=args.rope_theta) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - if self.use_qk_norm: - queries = self.q_norm(queries) - keys = self.k_norm(keys) - - queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - queries = self.rope(queries, offset=cache.offset) - keys = self.rope(keys, offset=cache.offset) - keys, values = cache.update_and_fetch(keys, values) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) - - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - - return self.o_proj(output) - - -class MiniMaxSparseMoeBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_experts_per_tok = args.num_experts_per_tok - - self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) - self.switch_mlp = SwitchGLU( - args.hidden_size, args.intermediate_size, 
args.num_local_experts - ) - self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) - - def __call__(self, x: mx.array) -> mx.array: - gates = self.gate(x.astype(mx.float32)) - - scores = mx.sigmoid(gates) - orig_scores = scores - scores = scores + self.e_score_correction_bias - - k = self.num_experts_per_tok - inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] - scores = mx.take_along_axis(orig_scores, inds, axis=-1) - y = self.switch_mlp(x, inds) - y = (y * scores[..., None]).sum(axis=-2) - return y - - -class MLXMiniMaxM2Block(nn.Module): - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__() - self.mlp_alpha_factor = args.mlp_alpha_factor - self.mlp_beta_factor = args.mlp_beta_factor - - self.self_attn = MLXMiniMaxM2Attention(args) - self.attn_alpha_factor = args.full_attn_alpha_factor - self.attn_beta_factor = args.full_attn_beta_factor - - self.block_sparse_moe = MiniMaxSparseMoeBlock(args) - - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Any] = None, - ) -> mx.array: - r = ( - self.input_layernorm(x) * self.attn_alpha_factor - + self.self_attn(x, mask, cache) * self.attn_beta_factor - ) - r = ( - self.block_sparse_moe(self.post_attention_layernorm(x)) * self.mlp_alpha_factor - + r * self.mlp_beta_factor - ) - return r - - -class ParallaxMiniMaxAttention(MLXMiniMaxM2Attention): - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - offset: int = 0, - lengths: Optional[mx.array] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: - """ - Attention forward pass with explicit KV cache handling. - - Args: - x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. - mask: (batch, n_q_heads, target_len, source_len) - cache: Optional tuple (past_k, past_v). - shape: (batch, n_kv_heads, S_past_padded, head_dim) - offset: source_len_padded (scalar, used for RoPE calculation). - - Returns: - output_h: (batch, target_len, hidden_dim) - Output hidden states. - new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. - new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. - """ - batch, target_len, _ = x.shape - - queries = self.q_proj(x) - keys = self.k_proj(x) - values = self.v_proj(x) - - queries = queries.reshape(batch, target_len, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - # for batch, rope offset is not correct due to padding in batch - queries_rotated = self.rope(queries, offset=offset) - keys_rotated = self.rope(keys, offset=offset) - - if cache is not None: - past_k, past_v = cache - if past_k is not None and past_v is not None: - if past_k.shape[2] != offset: - raise ValueError( - f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " - f"to match RoPE offset {offset} (S_past_padded)." 
- ) - final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2) - final_values_for_attn = mx.concatenate([past_v, values], axis=2) - else: - raise ValueError("cache was provided but one of k/v was None.") - else: - final_keys_for_attn = keys_rotated - final_values_for_attn = values - - output = scaled_dot_product_attention( - queries_rotated, - final_keys_for_attn, - final_values_for_attn, - scale=self.scale, - mask=mask, - cache=None, - ) - - output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) - return self.o_proj(output), (keys_rotated, values) - - -class ParallaxMiniMaxM2Block(MLXMiniMaxM2Block): - """A custom transformer block for Parallax, extending the MiniMaxM2 Block class. - This version handles the KV cache explicitly and returns new K and V states. - """ - - def __init__(self, args: ModelArgs, layer_idx: int): - super().__init__(args) - self.self_attn = ParallaxMiniMaxAttention(args) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - offset: int = 0, - lengths: Optional[mx.array] = None, - ): - r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, (k_cache, v_cache) - - @classmethod - def get_architecture(cls): - """Get the architecture name for the block.""" - return "MiniMaxM2CausalLM" - - -EntryClass = ParallaxMiniMaxM2Block diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index f7cde070..1bb38077 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -18,10 +18,7 @@ logger = get_logger(__name__) -MODEL_TYPE_MAP = { - "kimi_k2": "deepseek_v3", - "minimax": "minimax_m2", -} +MODEL_TYPE_MAP = {"kimi_k2": "deepseek_v3"} class MLXModelLoader: From 1569fa916dfe632abec846df2ceca1128b227e5f Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 10:26:48 +0800 Subject: [PATCH 08/25] update mlx-lm --- src/parallax/models/minimax.py | 2 +- src/parallax/server/http_server.py | 4 ++-- src/parallax/server/shard_loader.py | 4 ++-- src/parallax/sglang/model_runner.py | 4 ++-- tests/test_executor.py | 4 ++-- tests/test_model.py | 4 ++-- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index 56ff913b..62f7e629 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -105,7 +105,7 @@ def __call__( @classmethod def get_architecture(cls): """Get the architecture name for the block.""" - return "MiniMaxM2CausalLM" + return "MiniMaxM2ForCausalLM" EntryClass = ParallaxMiniMaxBlock diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 50320c2c..42099381 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -30,7 +30,7 @@ import zmq.asyncio from fastapi.responses import ORJSONResponse, StreamingResponse from mlx_lm.tokenizer_utils import StreamingDetokenizer -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import _download, load_config from pydantic import BaseModel from starlette.datastructures import State @@ -105,7 +105,7 @@ def __init__( self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True) self.processing_requests: Dict[str, HTTPRequestInfo] = {} # Load tokenizer for separate detokenizers - model_path, _ = get_model_path(model_path_str) + 
model_path = _download(model_path_str) config = load_config(model_path) self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 1bb38077..32627cf8 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -10,7 +10,7 @@ import mlx.core as mx import safetensors from mlx import nn -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import _download, load_config from parallax.server.model import ShardedModel from parallax.utils.tokenizer_utils import load_tokenizer @@ -112,7 +112,7 @@ def load( Returns: A tuple containing the loaded sharded MLX model and its configuration dictionary. """ - model_path, _ = get_model_path(self.model_path_str) + model_path = _download(self.model_path_str) config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index e488951b..4445eaa1 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -12,7 +12,7 @@ import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.utils import get_model_path, load_config +from mlx_lm.utils import _download, load_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( get_tp_group, @@ -548,7 +548,7 @@ def initialize_sgl_model_runner( - tokenizer: tokenizer driven by mlx-lm """ apply_parallax_monkey_patch() - model_path = get_model_path(original_model_path)[0] + model_path = _download(original_model_path) config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") diff --git a/tests/test_executor.py b/tests/test_executor.py index 1dd16093..8c05abca 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -4,7 +4,7 @@ import pytest from mlx_lm.generate import generate -from mlx_lm.utils import get_model_path, load_model +from mlx_lm.utils import _download, load_model from parallax.server.executor import Executor from parallax.server.request import InitialRequest @@ -12,7 +12,7 @@ MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16" -model_path = get_model_path(MODEL_REPO)[0] +model_path = _download(MODEL_REPO) ref_model, ref_config = load_model(model_path) ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None)) diff --git a/tests/test_model.py b/tests/test_model.py index a62dd2db..45634d6d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,7 +7,7 @@ import mlx.core as mx import pytest from mlx_lm.models.base import create_attention_mask -from mlx_lm.utils import get_model_path, load_model +from mlx_lm.utils import _download, load_model from parallax.server.server_info import ShardedModelInfo from parallax.server.shard_loader import MLXModelLoader @@ -18,7 +18,7 @@ TOTAL_LAYERS = 28 -model_path = get_model_path(REPO_ID)[0] +model_path = _download(REPO_ID) ref_model, ref_config = load_model(model_path) ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None)) From 79771fac18474d4e72337b6a5dd61d6457c336f3 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 10:30:42 +0800 Subject: [PATCH 09/25] update --- 
pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 70e42c58..fab731b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,8 +46,8 @@ mac = [ ] gpu = [ - "mlx-lm==0.28.0", - "mlx[cpu]==0.29.1", + "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git", + "mlx[cpu]==0.29.3", "sglang[all]==0.5.4.post1", ] From 947f8bcfd2d02b7c9989106818b44f98ccff5f5b Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 10:35:08 +0800 Subject: [PATCH 10/25] add model map --- src/parallax/launch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/parallax/launch.py b/src/parallax/launch.py index 93e18872..06d437b6 100644 --- a/src/parallax/launch.py +++ b/src/parallax/launch.py @@ -41,6 +41,7 @@ "Qwen/Qwen3-235B-A22B-GPTQ-Int4": "mlx-community/Qwen3-235B-A22B-4bit", "moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit", "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx", + "MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit", } if __name__ == "__main__": From fc13b1327d516291ce5f9cfdb702b0c6f6435c56 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 13:23:38 +0800 Subject: [PATCH 11/25] update params name --- src/parallax/models/minimax.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index 62f7e629..9c72fcb2 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -42,9 +42,13 @@ def __call__( keys = self.k_proj(x) values = self.v_proj(x) - queries = queries.reshape(batch, target_len, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(batch, target_len, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + queries = queries.reshape(batch, target_len, self.num_attention_heads, -1).transpose( + 0, 2, 1, 3 + ) + keys = keys.reshape(batch, target_len, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(batch, target_len, self.num_key_value_heads, -1).transpose( + 0, 2, 1, 3 + ) # for batch, rope offset is not correct due to padding in batch queries_rotated = self.rope(queries, offset=offset) From c0d4f9bcc9035c08a723a7051fd48c96684e7731 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 13:32:43 +0800 Subject: [PATCH 12/25] update name --- src/parallax/models/minimax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index 9c72fcb2..c31e79f2 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -102,7 +102,7 @@ def __call__( ): r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) h = x + r - r = self.mlp(self.post_attention_layernorm(h)) + r = self.block_sparse_moe(self.post_attention_layernorm(h)) out = h + r return out, (k_cache, v_cache) From 3a292fcd90399bb56b52ee2c022e42a1d5d8a957 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 13:39:38 +0800 Subject: [PATCH 13/25] add qk norm --- src/parallax/models/minimax.py | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index c31e79f2..ff51e702 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -21,27 +21,17 @@ def __call__( offset: int = 0, 
lengths: Optional[mx.array] = None, ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: - """ - Attention forward pass with explicit KV cache handling. - - Args: - x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. - mask: (batch, n_q_heads, target_len, source_len) - cache: Optional tuple (past_k, past_v). - shape: (batch, n_kv_heads, S_past_padded, head_dim) - offset: source_len_padded (scalar, used for RoPE calculation). - - Returns: - output_h: (batch, target_len, hidden_dim) - Output hidden states. - new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. - new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. - """ + batch, target_len, _ = x.shape queries = self.q_proj(x) keys = self.k_proj(x) values = self.v_proj(x) + if self.use_qk_norm: + queries = self.q_norm(queries) + keys = self.k_norm(keys) + queries = queries.reshape(batch, target_len, self.num_attention_heads, -1).transpose( 0, 2, 1, 3 ) From 2b7b26b900b123968033810d1515bea3317db05f Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 14:44:48 +0800 Subject: [PATCH 14/25] rebase minimax --- pyproject.toml | 8 +- src/parallax/models/minimax.py | 138 +++++++++++++++++++++++++++- src/parallax/server/http_server.py | 4 +- src/parallax/server/shard_loader.py | 4 +- src/parallax/sglang/model_runner.py | 4 +- tests/test_executor.py | 4 +- tests/test_model.py | 4 +- 7 files changed, 148 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fab731b3..1837274f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,13 +41,13 @@ parallax = "parallax.cli:main" mac = [ "torch==2.8.0", - "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git", - "mlx==0.29.3", + "mlx-lm==0.28.0", + "mlx==0.29.1", ] gpu = [ - "mlx-lm @ git+https://github.com/ml-explore/mlx-lm.git", - "mlx[cpu]==0.29.3", + "mlx-lm==0.28.0", + "mlx[cpu]==0.29.1", "sglang[all]==0.5.4.post1", ] diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index ff51e702..fdd6e1f5 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -5,10 +5,140 @@ import mlx.core as mx import mlx.nn as nn -from mlx_lm.models.base import scaled_dot_product_attention -from mlx_lm.models.minimax import MiniMaxAttention as MLXMiniMaxAttention -from mlx_lm.models.minimax import ModelArgs -from mlx_lm.models.minimax import MiniMaxDecoderLayer as MLXMiniMaxBlock +from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention +from mlx_lm.models.switch_layers import SwitchGLU + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + intermediate_size: int + num_attention_heads: int + num_key_value_heads: int + max_position_embeddings: int + num_experts_per_tok: int + num_local_experts: int + shared_intermediate_size: int + num_hidden_layers: int + rms_norm_eps: float + rope_theta: float + rotary_dim: int + vocab_size: int + tie_word_embeddings: bool = False + scoring_func: str = "sigmoid" + head_dim: Optional[int] = None + use_qk_norm: bool = True + + +class MLXMiniMaxAttention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.hidden_dim = hidden_size = args.hidden_size + + self.num_attention_heads = args.num_attention_heads + self.num_key_value_heads = args.num_key_value_heads + self.head_dim = head_dim = args.head_dim or hidden_size // args.num_attention_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * 
head_dim, bias=False) + self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False) + self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False) + self.o_proj = nn.Linear(self.num_attention_heads * head_dim, args.hidden_size, bias=False) + + self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False + if self.use_qk_norm: + self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps) + self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps) + + self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + if self.use_qk_norm: + queries = self.q_norm(queries) + keys = self.k_norm(keys) + + queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.o_proj(output) + + +class MLXMiniMaxSparseMoeBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_experts_per_tok = args.num_experts_per_tok + + self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False) + self.switch_mlp = SwitchGLU( + args.hidden_size, args.intermediate_size, args.num_local_experts + ) + self.e_score_correction_bias = mx.zeros((args.num_local_experts,)) + + def __call__(self, x: mx.array) -> mx.array: + gates = self.gate(x.astype(mx.float32)) + + scores = mx.sigmoid(gates) + orig_scores = scores + scores = scores + self.e_score_correction_bias + + k = self.num_experts_per_tok + inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k] + scores = mx.take_along_axis(orig_scores, inds, axis=-1) + + scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20) + scores = scores.astype(x.dtype) + + y = self.switch_mlp(x, inds) + y = (y * scores[..., None]).sum(axis=-2) + return y + + +class MLXMiniMaxBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + self.self_attn = MLXMiniMaxAttention(args) + + self.block_sparse_moe = MLXMiniMaxSparseMoeBlock(args) + + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = x + self.self_attn(self.input_layernorm(x), mask, cache) + r = r + self.block_sparse_moe(self.post_attention_layernorm(r)) + return r class ParallaxMiniMaxAttention(MLXMiniMaxAttention): diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 42099381..955ccc4a 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -30,7 +30,7 @@ import zmq.asyncio from 
fastapi.responses import ORJSONResponse, StreamingResponse from mlx_lm.tokenizer_utils import StreamingDetokenizer -from mlx_lm.utils import _download, load_config +from mlx_lm.utils import get_model_path, load_config from pydantic import BaseModel from starlette.datastructures import State @@ -105,7 +105,7 @@ def __init__( self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True) self.processing_requests: Dict[str, HTTPRequestInfo] = {} # Load tokenizer for separate detokenizers - model_path = _download(model_path_str) + model_path = get_model_path(model_path_str)[0] config = load_config(model_path) self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 32627cf8..68e26055 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -10,7 +10,7 @@ import mlx.core as mx import safetensors from mlx import nn -from mlx_lm.utils import _download, load_config +from mlx_lm.utils import get_model_path, load_config from parallax.server.model import ShardedModel from parallax.utils.tokenizer_utils import load_tokenizer @@ -112,7 +112,7 @@ def load( Returns: A tuple containing the loaded sharded MLX model and its configuration dictionary. """ - model_path = _download(self.model_path_str) + model_path = get_model_path(self.model_path_str)[0] config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py index 4445eaa1..e488951b 100755 --- a/src/parallax/sglang/model_runner.py +++ b/src/parallax/sglang/model_runner.py @@ -12,7 +12,7 @@ import sglang import sglang.srt.distributed.parallel_state import torch -from mlx_lm.utils import _download, load_config +from mlx_lm.utils import get_model_path, load_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import ( get_tp_group, @@ -548,7 +548,7 @@ def initialize_sgl_model_runner( - tokenizer: tokenizer driven by mlx-lm """ apply_parallax_monkey_patch() - model_path = _download(original_model_path) + model_path = get_model_path(original_model_path)[0] config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) dtype = config.get("torch_dtype", "bfloat16") diff --git a/tests/test_executor.py b/tests/test_executor.py index 8c05abca..1dd16093 100644 --- a/tests/test_executor.py +++ b/tests/test_executor.py @@ -4,7 +4,7 @@ import pytest from mlx_lm.generate import generate -from mlx_lm.utils import _download, load_model +from mlx_lm.utils import get_model_path, load_model from parallax.server.executor import Executor from parallax.server.request import InitialRequest @@ -12,7 +12,7 @@ MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16" -model_path = _download(MODEL_REPO) +model_path = get_model_path(MODEL_REPO)[0] ref_model, ref_config = load_model(model_path) ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None)) diff --git a/tests/test_model.py b/tests/test_model.py index 45634d6d..a62dd2db 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -7,7 +7,7 @@ import mlx.core as mx import pytest from mlx_lm.models.base import create_attention_mask -from mlx_lm.utils import _download, load_model +from 
mlx_lm.utils import get_model_path, load_model from parallax.server.server_info import ShardedModelInfo from parallax.server.shard_loader import MLXModelLoader @@ -18,7 +18,7 @@ TOTAL_LAYERS = 28 -model_path = _download(REPO_ID) +model_path = get_model_path(REPO_ID)[0] ref_model, ref_config = load_model(model_path) ref_tokenizer = load_tokenizer(model_path, eos_token_ids=ref_config.get("eos_token_id", None)) From 3001a5b88245448659d62037622e8abd0205c6de Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 14:46:57 +0800 Subject: [PATCH 15/25] update --- src/parallax/models/minimax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py index fdd6e1f5..42912b5d 100644 --- a/src/parallax/models/minimax.py +++ b/src/parallax/models/minimax.py @@ -1,7 +1,7 @@ # Copyright © 2025 Apple Inc. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn From 7219ee21e4b379c555985050efb93619e8605cbc Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 14:50:18 +0800 Subject: [PATCH 16/25] update --- src/parallax/server/http_server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 955ccc4a..a7ba6499 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -168,6 +168,8 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False): if is_first: role = "assistant" content = "" + if "minimax-m2" in self.model_path.lower(): + content = "" elif is_last: role = None content = None From 70831d7fa0ba7878ba32d1b7502334ed23c22a78 Mon Sep 17 00:00:00 2001 From: gufengc Date: Tue, 28 Oct 2025 15:07:52 +0800 Subject: [PATCH 17/25] update --- src/parallax/server/shard_loader.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 68e26055..3183e438 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -18,7 +18,10 @@ logger = get_logger(__name__) -MODEL_TYPE_MAP = {"kimi_k2": "deepseek_v3"} +MODEL_CLASS_MAP = { + "kimi_k2": "mlx_lm.models.deepseek_v3", + "minimax": "parallax.models.minimax", +} class MLXModelLoader: @@ -133,14 +136,16 @@ def load( # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. 
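(For reference: the remainder of this hunk replaces the old MODEL_TYPE_MAP aliasing with a module-path lookup through MODEL_CLASS_MAP. Read on its own, the resolution order introduced here behaves like the sketch below; resolve_arch_module is a hypothetical helper name, and only the two map entries added in this patch are assumed.)

# Standalone sketch, not part of the patch itself.
import importlib

MODEL_CLASS_MAP = {
    "kimi_k2": "mlx_lm.models.deepseek_v3",
    "minimax": "parallax.models.minimax",
}

def resolve_arch_module(model_type: str):
    # Known overrides first, otherwise fall back to the upstream mlx_lm layout.
    module_path = MODEL_CLASS_MAP.get(model_type, f"mlx_lm.models.{model_type}")
    return importlib.import_module(module_path)

# resolve_arch_module("minimax") -> parallax.models.minimax (the local MLX port)
# resolve_arch_module("qwen3")   -> mlx_lm.models.qwen3 (upstream default)
# The returned module is then expected to expose ModelArgs.from_dict(config).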
model_type = config.get("model_type") - if model_type in MODEL_TYPE_MAP: - model_type = MODEL_TYPE_MAP[model_type] - if not model_type: raise ValueError("model_type not found in config.json") + + if model_type in MODEL_CLASS_MAP: + model_class = MODEL_CLASS_MAP[model_type] + else: + model_class = f"mlx_lm.models.{model_type}" + try: - # Import from project's models directory - arch_module = importlib.import_module(f"parallax.models.{model_type}") + arch_module = importlib.import_module(model_class) model_args_class = getattr(arch_module, "ModelArgs") model_args = model_args_class.from_dict(config) From a6e09aa4180225d675b108d317b5e49a4bec4697 Mon Sep 17 00:00:00 2001 From: gufengc Date: Tue, 28 Oct 2025 15:08:39 +0800 Subject: [PATCH 18/25] update --- src/parallax/server/shard_loader.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 3183e438..d9c5a75b 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -84,22 +84,6 @@ def register_block_class(self): except Exception as e: logger.warning(f"Failed to load model from {model_file}: {e}") - def _create_default_model_args(self, config: Dict[str, Any]) -> Any: - """Create default model arguments from config.""" - model_args = { - "hidden_size": config.get("hidden_size", 0), - "num_attention_heads": config.get("num_attention_heads", 0), - "num_key_value_heads": config.get("num_key_value_heads", 0), - "num_hidden_layers": config.get("num_hidden_layers", 0), - "intermediate_size": config.get("intermediate_size", 0), - "vocab_size": config.get("vocab_size", 0), - "head_dim": config.get("head_dim", 128), - "num_local_experts": config.get("num_local_experts", None), - "num_experts_per_tok": config.get("num_experts_per_tok", None), - "moe_intermediate_size": config.get("moe_intermediate_size", None), - } - return type("ModelArgs", (), model_args)() - def load( self, lazy: bool = False, strict: bool = True ) -> Tuple[nn.Module, Dict[str, Any], Any]: From 1f64a2db0d57284a20396b815fe18b70b413e0df Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 15:19:14 +0800 Subject: [PATCH 19/25] hack --- src/backend/server/static_config.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 7cef11c9..0130f9c7 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -87,6 +87,10 @@ def _load_config_only(name: str) -> dict: elif quant_method in ("mxfp4", "int4", "awq", "gptq"): param_bytes_per_element = 0.5 + ## Only for hack, fix it when support different quantization bits + if "minimax_m2" in model_name.lower(): + param_bytes_per_element = 0.5 # MiniMax M2 uses FP16 weights + # get local experts num_local_experts = config.get("num_local_experts", None) if num_local_experts is None: From db957f2e28d0b5bc7aebfe3dfade2453523458ff Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 15:34:05 +0800 Subject: [PATCH 20/25] update param name --- src/backend/server/static_config.py | 3 ++- src/parallax/server/http_server.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 0130f9c7..b6e3d220 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -88,7 +88,8 @@ def _load_config_only(name: str) -> dict: param_bytes_per_element = 0.5 ## Only for hack, fix it when 
support different quantization bits - if "minimax_m2" in model_name.lower(): + print("Model:", model_name, "Quantization method:", quant_method) + if "minimax-m2" in model_name.lower(): param_bytes_per_element = 0.5 # MiniMax M2 uses FP16 weights # get local experts diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index a7ba6499..5da1694d 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -107,6 +107,7 @@ def __init__( # Load tokenizer for separate detokenizers model_path = get_model_path(model_path_str)[0] config = load_config(model_path) + self.model_path = model_path self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer) From 54d4dc0caec7c8a7b26ea19377f21d72e8eaae34 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 15:41:04 +0800 Subject: [PATCH 21/25] update --- src/parallax/server/http_server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 5da1694d..fd815688 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -107,7 +107,7 @@ def __init__( # Load tokenizer for separate detokenizers model_path = get_model_path(model_path_str)[0] config = load_config(model_path) - self.model_path = model_path + self.model_path_str = model_path_str self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer) @@ -169,7 +169,7 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False): if is_first: role = "assistant" content = "" - if "minimax-m2" in self.model_path.lower(): + if "minimax-m2" in self.model_path_str.lower(): content = "" elif is_last: role = None From efe0f33f6541434fd8c80df4dc1ae777463315a4 Mon Sep 17 00:00:00 2001 From: yuhao-zh Date: Tue, 28 Oct 2025 09:01:01 +0000 Subject: [PATCH 22/25] updata gpu load --- .../sglang/monkey_patch/minimax_m2_model.py | 90 +++++++------------ 1 file changed, 34 insertions(+), 56 deletions(-) diff --git a/src/parallax/sglang/monkey_patch/minimax_m2_model.py b/src/parallax/sglang/monkey_patch/minimax_m2_model.py index 6a42ec5b..06485577 100644 --- a/src/parallax/sglang/monkey_patch/minimax_m2_model.py +++ b/src/parallax/sglang/monkey_patch/minimax_m2_model.py @@ -11,12 +11,13 @@ maybe_remap_kv_scale_name, ) from sglang.srt.models.minimax_m2 import get_spec_layer_idx_from_weight_name - +logger = logging.getLogger(__name__) def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """Load model weights with proper mapping for MiniMax architecture.""" stacked_params_mapping = [ + # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), @@ -24,6 +25,8 @@ def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) ("gate_up_proj", "up_proj", 1), ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", @@ -32,44 +35,14 @@ def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) ) params_dict = dict(self.named_parameters()) - logger = logging.getLogger(__name__) - - 
weight_name_map = { - "lm_head.weight": "model.embed_tokens.weight", - } - - def resolve_param(name: str): - """Resolve weight name to actual parameter, handling tied weights and PP filtering.""" - if name in weight_name_map: - mapped_name = weight_name_map[name] - if mapped_name in params_dict: - logger.debug("Mapped '%s' -> '%s' (tied weight)", name, mapped_name) - return mapped_name, params_dict[mapped_name] - - if name in params_dict: - return name, params_dict[name] - - alt = f"model.{name}" - if alt in params_dict: - return alt, params_dict[alt] - - matches = [k for k in params_dict.keys() if k.endswith(name)] - if len(matches) == 1: - return matches[0], params_dict[matches[0]] - - if name in ("model.norm.weight", "model.embed_tokens.weight"): - logger.debug("Weight '%s' not found (PP-sliced)", name) - return None, None - - if ("lm_head" in name) or ("embed" in name): - sample = [k for k in params_dict.keys() if ("lm_head" in k) or ("embed" in k)] - if not sample: - sample = list(params_dict.keys())[:50] - logger.warning("Failed to resolve '%s'. Sample params: %s", name, sample) - return None, None - loaded_params: Set[str] = set() for name, loaded_weight in weights: + if "lm_head" in name: + pp_group = getattr(self, "pp_group", None) or get_pp_group() + if not pp_group.is_last_rank: + logger.debug("Skipping lm_head weight '%s' on non-last PP rank", name) + continue + layer_id = get_layer_id(name) if ( layer_id is not None @@ -77,28 +50,31 @@ def resolve_param(name: str): and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer) ): continue - if "rotary_emb.inv_freq" in name: continue spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: - continue + continue # skip spec decode layers for main model for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. if ("mlp.experts." in name) and name not in params_dict: continue name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - resolved_name, param = resolve_param(name) - if param is None: - if name not in ("model.norm.weight", "model.embed_tokens.weight"): - logger.warning("Skipping weight '%s' (no matching parameter)", name) - continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break @@ -109,28 +85,30 @@ def resolve_param(name: str): continue name = name.replace(weight_name, param_name) - resolved_name, param = resolve_param(name) - if param is None: - if name not in ("model.norm.weight", "model.embed_tokens.weight"): - logger.warning("Skipping expert weight '%s' (no matching parameter)", name) - continue + param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id) + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) break else: + # Skip loading extra bias for GPTQ models. 
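+                # (For reference: this final branch handles every weight that is
+                # neither a fused q/k/v or gate/up shard nor a per-expert tensor,
+                # i.e. norms, embeddings, and lm_head on the last PP rank. FP8
+                # kv-scale names are remapped below, then the parameter's own
+                # weight_loader, or default_weight_loader, is applied.)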
if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue - resolved_name, param = resolve_param(name) - if param is None: - if name not in ("model.norm.weight", "model.embed_tokens.weight"): - logger.warning("Skipping weight '%s' (no matching parameter)", name) - continue - weight_loader = getattr(param, "weight_loader", default_weight_loader) + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params From 353b60304e925f72caac81db185e7e7584e9a3fa Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 17:02:55 +0800 Subject: [PATCH 23/25] pre-commit --- src/parallax/server/shard_loader.py | 2 +- src/parallax/sglang/monkey_patch/minimax_m2_model.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index d9c5a75b..be03e00f 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -122,7 +122,7 @@ def load( model_type = config.get("model_type") if not model_type: raise ValueError("model_type not found in config.json") - + if model_type in MODEL_CLASS_MAP: model_class = MODEL_CLASS_MAP[model_type] else: diff --git a/src/parallax/sglang/monkey_patch/minimax_m2_model.py b/src/parallax/sglang/monkey_patch/minimax_m2_model.py index 06485577..f14e59bf 100644 --- a/src/parallax/sglang/monkey_patch/minimax_m2_model.py +++ b/src/parallax/sglang/monkey_patch/minimax_m2_model.py @@ -11,8 +11,10 @@ maybe_remap_kv_scale_name, ) from sglang.srt.models.minimax_m2 import get_spec_layer_idx_from_weight_name + logger = logging.getLogger(__name__) + def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): """Load model weights with proper mapping for MiniMax architecture.""" @@ -42,7 +44,7 @@ def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) if not pp_group.is_last_rank: logger.debug("Skipping lm_head weight '%s' on non-last PP rank", name) continue - + layer_id = get_layer_id(name) if ( layer_id is not None @@ -106,9 +108,7 @@ def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) continue param = params_dict[name] - weight_loader = getattr( - param, "weight_loader", default_weight_loader - ) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params From 7175bc7c2e7dbe82b10aa6292e96a0a3cc77e345 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 17:07:07 +0800 Subject: [PATCH 24/25] update --- src/backend/server/static_config.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index b6e3d220..9b063f11 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -87,10 +87,9 @@ def _load_config_only(name: str) -> dict: elif quant_method in ("mxfp4", "int4", "awq", "gptq"): param_bytes_per_element = 0.5 - ## Only for hack, fix it when support different quantization bits - print("Model:", model_name, "Quantization method:", quant_method) + # Only for hack, fix it when support different quantization bits if "minimax-m2" in model_name.lower(): - param_bytes_per_element = 0.5 # MiniMax M2 uses FP16 weights + 
param_bytes_per_element = 0.5 # get local experts num_local_experts = config.get("num_local_experts", None) From b9da67ac2d67a7035aafed8a586a8172d44d8802 Mon Sep 17 00:00:00 2001 From: yuhao_zhang Date: Tue, 28 Oct 2025 17:12:51 +0800 Subject: [PATCH 25/25] rm hack --- src/backend/server/static_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index 9b063f11..39dd184a 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -88,8 +88,8 @@ def _load_config_only(name: str) -> dict: param_bytes_per_element = 0.5 # Only for hack, fix it when support different quantization bits - if "minimax-m2" in model_name.lower(): - param_bytes_per_element = 0.5 + # if "minimax-m2" in model_name.lower(): + # param_bytes_per_element = 0.5 # get local experts num_local_experts = config.get("num_local_experts", None)
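The param_bytes_per_element value that these last hunks keep adjusting feeds a simple estimate of weight memory per model. A rough sketch of that calculation is below; the helper name, the fp8 entry, and the 2-byte unquantized default are assumptions for illustration, while the 0.5 bytes per parameter for mxfp4/int4/awq/gptq comes from the code above.

# Illustrative only, not the project's actual sizing code.
from typing import Optional

def estimate_weight_gib(num_params: float, quant_method: Optional[str]) -> float:
    """Rough weight-memory estimate in GiB from a parameter count and quant method."""
    bytes_per_param = {
        "fp8": 1.0,
        "mxfp4": 0.5,
        "int4": 0.5,
        "awq": 0.5,
        "gptq": 0.5,
    }.get(quant_method or "", 2.0)  # assume unquantized bf16/fp16 weights otherwise
    return num_params * bytes_per_param / (1024**3)

# A 10B-parameter checkpoint: bf16 ~18.6 GiB, int4/awq/gptq ~4.7 GiB of weights.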
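As a closing note on the MLX side of this series: MLXMiniMaxSparseMoeBlock selects experts on bias-corrected sigmoid scores (gate output plus e_score_correction_bias) but weights the selected experts with the original, uncorrected scores, renormalised over the top-k. A plain-Python rerun of that routing for a single token, with invented numbers, makes the effect of the correction bias visible; the real block does the same with mx.sigmoid, mx.argpartition and mx.take_along_axis over batched arrays.

# Numbers are made up for illustration.
gate_scores = [0.10, 0.80, 0.55, 0.30]        # sigmoid(gate(x)) per expert
bias        = [0.00, -0.50, 0.00, 0.20]       # e_score_correction_bias
k = 2

corrected = [s + b for s, b in zip(gate_scores, bias)]      # used only to pick experts
top_k = sorted(range(len(corrected)), key=lambda i: -corrected[i])[:k]

weights = [gate_scores[i] for i in top_k]                   # original scores are reused
total = sum(weights) + 1e-20
weights = [w / total for w in weights]                      # renormalise over the top-k

print(top_k, [round(w, 3) for w in weights])
# -> [2, 3] [0.647, 0.353]
# Experts 2 and 3 are chosen even though expert 1 has the highest raw score,
# because its large negative correction bias pushes it out of the top-k.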