feat(model): add minimax m2 inplace without updata mlx-lm #161
Merged: yuhao-zh merged 25 commits into GradientHQ:main from yuhao-zh:temp/add_minimax_m2_inplace on Oct 28, 2025.
Commits (25):
- 6dd4499 add gpu support
- 0f6d245 update mlx support
- 53b5aad update (yuhao-zh)
- cfeab39 update (yuhao-zh)
- c59dadb update (yuhao-zh)
- c7a2285 update model_type Map (yuhao-zh)
- 6e63edf update model (yuhao-zh)
- 1569fa9 update mlx-lm (yuhao-zh)
- 79771fa update (yuhao-zh)
- 947f8bc add model map (yuhao-zh)
- fc13b13 update params name (yuhao-zh)
- c0d4f9b update name (yuhao-zh)
- 3a292fc add qk norm (yuhao-zh)
- 2b7b26b rebase minimax (yuhao-zh)
- 3001a5b update (yuhao-zh)
- 7219ee2 update (yuhao-zh)
- 70831d7 update (yuhao-zh)
- a6e09aa update (gufengc)
- 1f64a2d hack (gufengc)
- db957f2 update param name (yuhao-zh)
- 54d4dc0 update (yuhao-zh)
- efe0f33 updata gpu load (yuhao-zh)
- 353b603 pre-commit
- 7175bc7 update (yuhao-zh)
- b9da67a rm hack (yuhao-zh)
The diff adds one new file (+235 lines):

```python
# Copyright © 2025 Apple Inc.

from dataclasses import dataclass
from typing import Any, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, scaled_dot_product_attention
from mlx_lm.models.switch_layers import SwitchGLU


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    intermediate_size: int
    num_attention_heads: int
    num_key_value_heads: int
    max_position_embeddings: int
    num_experts_per_tok: int
    num_local_experts: int
    shared_intermediate_size: int
    num_hidden_layers: int
    rms_norm_eps: float
    rope_theta: float
    rotary_dim: int
    vocab_size: int
    tie_word_embeddings: bool = False
    scoring_func: str = "sigmoid"
    head_dim: Optional[int] = None
    use_qk_norm: bool = True


class MLXMiniMaxAttention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.hidden_dim = hidden_size = args.hidden_size

        self.num_attention_heads = args.num_attention_heads
        self.num_key_value_heads = args.num_key_value_heads
        self.head_dim = head_dim = args.head_dim or hidden_size // args.num_attention_heads
        self.scale = head_dim**-0.5

        self.q_proj = nn.Linear(args.hidden_size, self.num_attention_heads * head_dim, bias=False)
        self.k_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
        self.v_proj = nn.Linear(args.hidden_size, self.num_key_value_heads * head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_attention_heads * head_dim, args.hidden_size, bias=False)

        self.use_qk_norm = args.use_qk_norm if hasattr(args, "use_qk_norm") else False
        if self.use_qk_norm:
            self.q_norm = nn.RMSNorm(head_dim * self.num_attention_heads, eps=args.rms_norm_eps)
            self.k_norm = nn.RMSNorm(head_dim * self.num_key_value_heads, eps=args.rms_norm_eps)

        self.rope = nn.RoPE(args.rotary_dim, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        queries = queries.reshape(B, L, self.num_attention_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )

        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.o_proj(output)


class MLXMiniMaxSparseMoeBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.num_experts_per_tok = args.num_experts_per_tok

        self.gate = nn.Linear(args.hidden_size, args.num_local_experts, bias=False)
        self.switch_mlp = SwitchGLU(
            args.hidden_size, args.intermediate_size, args.num_local_experts
        )
        self.e_score_correction_bias = mx.zeros((args.num_local_experts,))

    def __call__(self, x: mx.array) -> mx.array:
        gates = self.gate(x.astype(mx.float32))

        scores = mx.sigmoid(gates)
        orig_scores = scores
        scores = scores + self.e_score_correction_bias

        k = self.num_experts_per_tok
        inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
        scores = mx.take_along_axis(orig_scores, inds, axis=-1)

        scores = scores / (mx.sum(scores, axis=-1, keepdims=True) + 1e-20)
        scores = scores.astype(x.dtype)

        y = self.switch_mlp(x, inds)
        y = (y * scores[..., None]).sum(axis=-2)
        return y


class MLXMiniMaxBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.self_attn = MLXMiniMaxAttention(args)

        self.block_sparse_moe = MLXMiniMaxSparseMoeBlock(args)

        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        r = x + self.self_attn(self.input_layernorm(x), mask, cache)
        r = r + self.block_sparse_moe(self.post_attention_layernorm(r))
        return r


class ParallaxMiniMaxAttention(MLXMiniMaxAttention):

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
        offset: int = 0,
        lengths: Optional[mx.array] = None,
    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:

        batch, target_len, _ = x.shape

        queries = self.q_proj(x)
        keys = self.k_proj(x)
        values = self.v_proj(x)

        if self.use_qk_norm:
            queries = self.q_norm(queries)
            keys = self.k_norm(keys)

        queries = queries.reshape(batch, target_len, self.num_attention_heads, -1).transpose(
            0, 2, 1, 3
        )
        keys = keys.reshape(batch, target_len, self.num_key_value_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(batch, target_len, self.num_key_value_heads, -1).transpose(
            0, 2, 1, 3
        )

        # for batch, rope offset is not correct due to padding in batch
        queries_rotated = self.rope(queries, offset=offset)
        keys_rotated = self.rope(keys, offset=offset)

        if cache is not None:
            past_k, past_v = cache
            if past_k is not None and past_v is not None:
                if past_k.shape[2] != offset:
                    raise ValueError(
                        f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} "
                        f"to match RoPE offset {offset} (S_past_padded)."
                    )
                final_keys_for_attn = mx.concatenate([past_k, keys_rotated], axis=2)
                final_values_for_attn = mx.concatenate([past_v, values], axis=2)
            else:
                raise ValueError("cache was provided but one of k/v was None.")
        else:
            final_keys_for_attn = keys_rotated
            final_values_for_attn = values

        output = scaled_dot_product_attention(
            queries_rotated,
            final_keys_for_attn,
            final_values_for_attn,
            scale=self.scale,
            mask=mask,
            cache=None,
        )

        output = output.transpose(0, 2, 1, 3).reshape(batch, target_len, -1)
        return self.o_proj(output), (keys_rotated, values)


class ParallaxMiniMaxBlock(MLXMiniMaxBlock):
    """A custom transformer block for Parallax, extending the MiniMax Block class.
    This version handles the KV cache explicitly and returns new K and V states.
    """

    def __init__(self, args: ModelArgs, layer_idx: int):
        super().__init__(args)
        self.self_attn = ParallaxMiniMaxAttention(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
        offset: int = 0,
        lengths: Optional[mx.array] = None,
    ):
        r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
        h = x + r
        r = self.block_sparse_moe(self.post_attention_layernorm(h))
        out = h + r
        return out, (k_cache, v_cache)

    @classmethod
    def get_architecture(cls):
        """Get the architecture name for the block."""
        return "MiniMaxM2ForCausalLM"


EntryClass = ParallaxMiniMaxBlock
```
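To sanity-check the MLX-side classes, a minimal smoke test can exercise one block end to end. This is a sketch only: the import path `minimax_m2` is hypothetical (it depends on where the PR places the file), and the tiny dimensions are illustrative, not MiniMax M2's real configuration.

```python
import mlx.core as mx

# Hypothetical module name; adjust to wherever this file lives in the repo.
from minimax_m2 import ModelArgs, MLXMiniMaxBlock

args = ModelArgs(
    model_type="minimax",
    hidden_size=64,
    intermediate_size=128,
    num_attention_heads=4,
    num_key_value_heads=2,        # GQA: 2 KV heads shared across 4 query heads
    max_position_embeddings=2048,
    num_experts_per_tok=2,
    num_local_experts=4,
    shared_intermediate_size=0,
    num_hidden_layers=2,
    rms_norm_eps=1e-5,
    rope_theta=10000.0,
    rotary_dim=16,                # equal to head_dim = 64 / 4
    vocab_size=100,
)

block = MLXMiniMaxBlock(args)
x = mx.random.normal((1, 8, args.hidden_size))  # (batch, seq_len, hidden)
y = block(x)  # mask=None, cache=None: plain forward pass, no causal mask
print(y.shape)  # (1, 8, 64)
```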
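One detail of `MLXMiniMaxSparseMoeBlock` worth calling out: `e_score_correction_bias` influences only which experts are selected, while the mixing weights come from the uncorrected sigmoid scores, renormalized over the top-k. A toy sketch of just that selection math, with made-up logits and bias values:

```python
import mlx.core as mx

# One token, 4 experts; the numbers are invented for illustration.
gates = mx.array([0.2, 1.5, -0.3, 0.9])
bias = mx.array([0.5, 0.0, 0.0, -0.5])  # stands in for e_score_correction_bias

scores = mx.sigmoid(gates)   # ~[0.55, 0.82, 0.43, 0.71]
biased = scores + bias       # ~[1.05, 0.82, 0.43, 0.21] -- used only to pick experts
k = 2
inds = mx.argpartition(-biased, kth=k - 1, axis=-1)[..., :k]      # experts 0 and 1
topk = mx.take_along_axis(scores, inds, axis=-1)                  # weights use *unbiased* scores
weights = topk / (mx.sum(topk, axis=-1, keepdims=True) + 1e-20)   # ~[0.40, 0.60]
print(inds, weights)
```

With these numbers the bias swaps expert 3 out for expert 0 even though expert 3's raw score is higher, yet the output weights are still computed from the raw scores. This lets a load-balancing bias steer routing without distorting the mixture itself.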
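The Parallax variants replace the stateful mlx-lm cache object with explicit `(k, v)` tensors: the block returns only the newly rotated keys and values, and the caller is responsible for tracking the RoPE offset and concatenating the cache between steps. A sketch of the prefill-then-decode handshake, reusing the same hypothetical import and toy arguments as the smoke test above:

```python
import mlx.core as mx
from minimax_m2 import ModelArgs, ParallaxMiniMaxBlock  # hypothetical import path

args = ModelArgs(
    model_type="minimax", hidden_size=64, intermediate_size=128,
    num_attention_heads=4, num_key_value_heads=2, max_position_embeddings=2048,
    num_experts_per_tok=2, num_local_experts=4, shared_intermediate_size=0,
    num_hidden_layers=2, rms_norm_eps=1e-5, rope_theta=10000.0,
    rotary_dim=16, vocab_size=100,
)
block = ParallaxMiniMaxBlock(args, layer_idx=0)

# Prefill: no cache yet, RoPE offset starts at 0.
prompt = mx.random.normal((1, 4, args.hidden_size))
h, (k, v) = block(prompt, cache=None, offset=0)

# Decode one token: pass the cached (k, v) back in and set offset to the
# cached length; the attention asserts that k.shape[2] == offset.
step = mx.random.normal((1, 1, args.hidden_size))
h, (k_new, v_new) = block(step, cache=(k, v), offset=k.shape[2])

# The block returns k/v only for the *new* tokens, so the caller
# concatenates them onto its cache along the sequence axis.
k = mx.concatenate([k, k_new], axis=2)
v = mx.concatenate([v, v_new], axis=2)
```

This explicit-cache design is what the error check in `ParallaxMiniMaxAttention` guards: if the caller's offset drifts out of sync with the cached sequence length, the rotary embeddings of new tokens would no longer line up with the cached keys.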