diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py
index 1d4f3a8e..39dd184a 100644
--- a/src/backend/server/static_config.py
+++ b/src/backend/server/static_config.py
@@ -49,6 +49,7 @@
     "deepseek-ai/DeepSeek-R1",
     "deepseek-ai/DeepSeek-V3",
     "deepseek-ai/DeepSeek-V2",
+    "MiniMaxAI/MiniMax-M2",
 ]
 
 NODE_JOIN_COMMAND_LOCAL_NETWORK = """parallax join"""
@@ -86,6 +87,10 @@ def _load_config_only(name: str) -> dict:
     elif quant_method in ("mxfp4", "int4", "awq", "gptq"):
         param_bytes_per_element = 0.5
 
+    # Temporary hack; revisit once different quantization bit widths are supported.
+    # if "minimax-m2" in model_name.lower():
+    #     param_bytes_per_element = 0.5
+
     # get local experts
     num_local_experts = config.get("num_local_experts", None)
     if num_local_experts is None:
diff --git a/src/parallax/launch.py b/src/parallax/launch.py
index 93e18872..06d437b6 100644
--- a/src/parallax/launch.py
+++ b/src/parallax/launch.py
@@ -41,6 +41,7 @@
     "Qwen/Qwen3-235B-A22B-GPTQ-Int4": "mlx-community/Qwen3-235B-A22B-4bit",
     "moonshotai/Kimi-K2-Instruct": "mlx-community/Kimi-K2-Instruct-4bit",
     "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx",
+    "MiniMaxAI/MiniMax-M2": "mlx-community/MiniMax-M2-4bit",
 }
 
 if __name__ == "__main__":
diff --git a/src/parallax/models/minimax.py b/src/parallax/models/minimax.py
new file mode 100644
index 00000000..42912b5d
--- /dev/null
+++ b/src/parallax/models/minimax.py
@@ -0,0 +1,235 @@
+# Copyright © 2025 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Any, Optional, Tuple
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention
+from mlx_lm.models.switch_layers import SwitchGLU
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    intermediate_size: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    max_position_embeddings: int
+    num_experts_per_tok: int
+    num_local_experts: int
+    shared_intermediate_size: int
+    num_hidden_layers: int
+    rms_norm_eps: float
+    rope_theta: float
+    rotary_dim: int
+    vocab_size: int
+    tie_word_embeddings: bool = False
+    scoring_func: str = "sigmoid"
+    head_dim: Optional[int] = None
+    use_qk_norm: bool = True
+
+
+class MLXMiniMaxAttention(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.hidden_dim = hidden_size = args.hidden_size
+
+        self.num_attention_heads = args.num_attention_heads
+        self.num_key_value_heads = args.num_key_value_heads
+        self.head_dim = head_dim = args.head_dim or hidden_size // args.num_attention_heads
+        self.scale = head_dim**-0.5
+
+        self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * head_dim, bias=False)
+        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
+        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
+        self.o_proj = nn.Linear(self.num_attention_heads * head_dim, args.hidden_size, bias=False)
+
+        self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False
+        if self.use_qk_norm:
+            self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps)
+            self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps)
+
+        self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        if self.use_qk_norm:
+            queries = self.q_norm(queries)
+            keys = self.k_norm(keys)
+
+        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.o_proj(output)
+
+
+class MLXMiniMaxSparseMoeBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.num_experts_per_tok = args.num_experts_per_tok
+
+        self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
+        self.switch_mlp = SwitchGLU(
+            args.hidden_size, args.intermediate_size, args.num_local_experts
+        )
+        self.e_score_correction_bias = mx.zeros((args.num_local_experts,))
+
+    def __call__(self, x: mx.array) -> mx.array:
+        gates = self.gate(x.astype(mx.float32))
+
+        scores = mx.sigmoid(gates)
+        orig_scores = scores
+        scores = scores + self.e_score_correction_bias
+
+        k = self.num_experts_per_tok
+        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
+        scores = mx.take_along_axis(orig_scores, inds, axis=-1)
+
+        scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20)
+        scores = scores.astype(x.dtype)
+
+        y = self.switch_mlp(x, inds)
+        y = (y * scores[..., None]).sum(axis=-2)
+        return y
+
+
+class MLXMiniMaxBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+
+        self.self_attn = MLXMiniMaxAttention(args)
+
+        self.block_sparse_moe = MLXMiniMaxSparseMoeBlock(args)
+
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Any] = None,
+    ) -> mx.array:
+        r = x + self.self_attn(self.input_layernorm(x), mask, cache)
+        r = r + self.block_sparse_moe(self.post_attention_layernorm(r))
+        return r
+
+
+class ParallaxMiniMaxAttention(MLXMiniMaxAttention):
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        offset: int = 0,
+        lengths: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+
+        batch, target_len, _ = x.shape
+
+        queries = self.q_proj(x)
+        keys = self.k_proj(x)
+        values = self.v_proj(x)
+
+        if self.use_qk_norm:
+            queries = self.q_norm(queries)
+            keys = self.k_norm(keys)
+
+        queries = queries.reshape(batch, target_len, self.num_attention_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+        keys = keys.reshape(batch, target_len, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(batch, target_len, self.num_key_value_heads, -1).transpose(
+            0, 2, 1, 3
+        )
+
+        # For batched requests the RoPE offset may be inexact because of padding in the batch.
+        queries_rotated = self.rope(queries, offset=offset)
+        keys_rotated = self.rope(keys, offset=offset)
+
+        if cache is not None:
+            past_k, past_v = cache
+            if past_k is not None and past_v is not None:
+                if past_k.shape[2] != offset:
+                    raise ValueError(
f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " + f"to match RoPE offset {offset} (S_past_padded)." + ) + final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2) + final_values_for_attn = mx.concatenate([past_v, values], axis=2) + else: + raise ValueError("cache was provided but one of k/v was None.") + else: + final_keys_for_attn = keys_rotated + final_values_for_attn = values + + output = scaled_dot_product_attention( + queries_rotated, + final_keys_for_attn, + final_values_for_attn, + scale=self.scale, + mask=mask, + cache=None, + ) + + output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1) + return self.o_proj(output), (keys_rotated, values) + + +class ParallaxMiniMaxBlock(MLXMiniMaxBlock): + """A custom transformer block for Parallax, extending the MiniMax Block class. + This version handles the KV cache explicitly and returns new K and V states. + """ + + def __init__(self, args: ModelArgs, layer_idx: int): + super().__init__(args) + self.self_attn = ParallaxMiniMaxAttention(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ): + r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset) + h = x + r + r = self.block_sparse_moe(self.post_attention_layernorm(h)) + out = h + r + return out, (k_cache, v_cache) + + @classmethod + def get_architecture(cls): + """Get the architecture name for the block.""" + return "MiniMaxM2ForCausalLM" + + +EntryClass = ParallaxMiniMaxBlock diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py index 50320c2c..fd815688 100644 --- a/src/parallax/server/http_server.py +++ b/src/parallax/server/http_server.py @@ -105,8 +105,9 @@ def __init__( self.recv_from_executor = get_zmq_socket(context, zmq.PULL, executor_output_ipc_name, True) self.processing_requests: Dict[str, HTTPRequestInfo] = {} # Load tokenizer for separate detokenizers - model_path, _ = get_model_path(model_path_str) + model_path = get_model_path(model_path_str)[0] config = load_config(model_path) + self.model_path_str = model_path_str self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer) @@ -168,6 +169,8 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False): if is_first: role = "assistant" content = "" + if "minimax-m2" in self.model_path_str.lower(): + content = "" elif is_last: role = None content = None diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index f15250d6..be03e00f 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -18,6 +18,11 @@ logger = get_logger(__name__) +MODEL_CLASS_MAP = { + "kimi_k2": "mlx_lm.models.deepseek_v3", + "minimax": "parallax.models.minimax", +} + class MLXModelLoader: """ @@ -94,7 +99,7 @@ def load( Returns: A tuple containing the loaded sharded MLX model and its configuration dictionary. """ - model_path, _ = get_model_path(self.model_path_str) + model_path = get_model_path(self.model_path_str)[0] config = load_config(model_path) tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None)) @@ -115,14 +120,19 @@ def load( # We need the model object to know its structure and which layers it owns. 
         # This part mirrors the logic from the provided utils.py to get model_args.
         model_type = config.get("model_type")
-        if model_type == "kimi_k2":
-            model_type = "deepseek_v3"
         if not model_type:
             raise ValueError("model_type not found in config.json")
+
+        if model_type in MODEL_CLASS_MAP:
+            model_class = MODEL_CLASS_MAP[model_type]
+        else:
+            model_class = f"mlx_lm.models.{model_type}"
+
         try:
-            arch_module = importlib.import_module(f"mlx_lm.models.{model_type}")
+            arch_module = importlib.import_module(model_class)
             model_args_class = getattr(arch_module, "ModelArgs")
             model_args = model_args_class.from_dict(config)
+
         except (ImportError, AttributeError) as e:
             raise ValueError(f"Failed to load architecture for model_type '{model_type}'.") from e
 
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index 37c075aa..e488951b 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -488,6 +488,14 @@ def monkey_patch_triton_backend_init():
     apply_triton_backend_init_monkey_patch()
 
 
+def monkey_patch_minimax_m2_model():
+    from parallax.sglang.monkey_patch.minimax_m2_model import (
+        apply_minimax_m2_monkey_patch,
+    )
+
+    apply_minimax_m2_monkey_patch()
+
+
 def form_sgl_server_args(
     model_path: str,
     dtype: str = "bfloat16",
@@ -520,6 +528,7 @@ def apply_parallax_monkey_patch():
     monkey_patch_qwen3_next()
     monkey_patch_gpt_oss()
     monkey_patch_triton_backend_init()
+    monkey_patch_minimax_m2_model()
 
 
 def initialize_sgl_model_runner(
diff --git a/src/parallax/sglang/monkey_patch/minimax_m2_model.py b/src/parallax/sglang/monkey_patch/minimax_m2_model.py
new file mode 100644
index 00000000..f14e59bf
--- /dev/null
+++ b/src/parallax/sglang/monkey_patch/minimax_m2_model.py
@@ -0,0 +1,151 @@
+import logging
+from typing import Iterable, Optional, Set, Tuple
+
+import torch
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.utils import get_layer_id
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.minimax_m2 import get_spec_layer_idx_from_weight_name
+
+logger = logging.getLogger(__name__)
+
+
+def monkey_patch_load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    """Load model weights with proper mapping for MiniMax architecture."""
+
+    stacked_params_mapping = [
+        # (param_name, shard_name, shard_id)
+        ("qkv_proj", "q_proj", "q"),
+        ("qkv_proj", "k_proj", "k"),
+        ("qkv_proj", "v_proj", "v"),
+        ("gate_up_proj", "gate_proj", 0),
+        ("gate_up_proj", "up_proj", 1),
+    ]
+
+    # Params for weights, fp8 weight scales, fp8 activation scales
+    # (param_name, weight_name, expert_id, shard_id)
+    expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        ckpt_gate_proj_name="w1",
+        ckpt_down_proj_name="w2",
+        ckpt_up_proj_name="w3",
+        num_experts=self.config.num_local_experts,
+    )
+
+    params_dict = dict(self.named_parameters())
+    loaded_params: Set[str] = set()
+    for name, loaded_weight in weights:
+        if "lm_head" in name:
+            pp_group = getattr(self, "pp_group", None) or get_pp_group()
+            if not pp_group.is_last_rank:
+                logger.debug("Skipping lm_head weight '%s' on non-last PP rank", name)
+                continue
+
+        layer_id = get_layer_id(name)
+        if (
+            layer_id is not None
+            and hasattr(self.model, "start_layer")
+            and (layer_id < self.model.start_layer or layer_id >= self.model.end_layer)
+        ):
+            continue
"rotary_emb.inv_freq" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def apply_minimax_m2_monkey_patch(): + """Apply monkey patches to MiniMax M2 for PP support and weight loading.""" + import sglang.srt.models.minimax_m2 as m2_module + + orig_init = m2_module.MiniMaxM2ForCausalLM.__init__ + + def pp_init(self, config, quant_config=None, prefix=""): + orig_init(self, config, quant_config, prefix) + self.pp_group = get_pp_group() + + def pp_forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch, + inputs_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + **kwargs, + ): + hidden_states = self.model( + input_ids, positions, forward_batch, inputs_embeds, pp_proxy_tensors + ) + + if isinstance(hidden_states, PPProxyTensors): + return hidden_states + + pp_group = getattr(self, "pp_group", None) or get_pp_group() + if pp_group.is_last_rank: + return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch) + else: + return hidden_states + + m2_module.MiniMaxM2ForCausalLM.__init__ = pp_init + m2_module.MiniMaxM2ForCausalLM.forward = pp_forward + m2_module.MiniMaxM2ForCausalLM.load_weights = monkey_patch_load_weights