diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py
index 85427cab..004cf3b9 100644
--- a/src/backend/server/static_config.py
+++ b/src/backend/server/static_config.py
@@ -7,6 +7,12 @@
 # Supported model list
 MODEL_LIST = [
     "Qwen/Qwen3-0.6B",
+    "openai/gpt-oss-20b",
+    "openai/gpt-oss-120b",
+    "moonshotai/Kimi-K2-Instruct",
+    "moonshotai/Kimi-K2-Instruct-0905",
+    "Qwen/Qwen3-Next-80B-A3B-Instruct",
+    "Qwen/Qwen3-Next-80B-A3B-Thinking",
     # "Qwen/Qwen3-8B",
     # "Qwen/Qwen3-8B-FP8",
     "Qwen/Qwen3-32B",
@@ -16,14 +22,10 @@
     # "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8",
     "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8",
     "Qwen/Qwen3-235B-A22B-Thinking-2507-FP8",
-    "Qwen/Qwen3-Next-80B-A3B-Instruct",
-    "Qwen/Qwen3-Next-80B-A3B-Thinking",
     # "Qwen/Qwen2.5-3B-Instruct",
     # "Qwen/Qwen2.5-7B-Instruct",
     # "Qwen/Qwen2.5-14B-Instruct",
     "Qwen/Qwen2.5-72B-Instruct",
-    "openai/gpt-oss-20b",
-    "openai/gpt-oss-120b",
     "nvidia/Llama-3.3-70B-Instruct-FP8",
     "nvidia/Llama-3.1-70B-Instruct-FP8",
     "nvidia/Llama-3.1-8B-Instruct-FP8",
@@ -56,6 +58,8 @@ def get_model_info(model_name):
     model_info = ModelInfo(
         model_name=model_name,
         head_size=config.get("head_dim", 128),
+        qk_nope_head_dim=config.get("qk_nope_head_dim", None),
+        qk_rope_head_dim=config.get("qk_rope_head_dim", None),
         hidden_dim=config.get("hidden_size", 0),
         intermediate_dim=config.get("intermediate_size", 0),
         num_attention_heads=config.get("num_attention_heads", 0),
diff --git a/src/parallax/models/deepseek_v2.py b/src/parallax/models/deepseek_v2.py
new file mode 100644
index 00000000..a9bd2254
--- /dev/null
+++ b/src/parallax/models/deepseek_v2.py
@@ -0,0 +1,130 @@
+"""
+Defines the DeepSeek V2 model.
+"""
+
+from typing import Optional, Tuple
+
+import mlx.core as mx
+from mlx_lm.models.base import scaled_dot_product_attention
+from mlx_lm.models.deepseek_v2 import DeepseekV2Attention as MLXDeepseekV2Attention
+from mlx_lm.models.deepseek_v2 import DeepseekV2DecoderLayer as MLXDeepseekV2Block
+from mlx_lm.models.deepseek_v2 import ModelArgs
+
+
+class ParallaxDeepSeekV2Attention(MLXDeepseekV2Attention):
+    """A custom attention module for Parallax, extending the DeepseekV2 Attention class.
+
+    We apply explicit KV cache handling and pass `offset` in directly from the Request.
+    This version returns the new K and V states for external caching.
+    """
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        offset: int = 0,
+        lengths: Optional[mx.array] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+        """
+        Attention forward pass with explicit KV cache handling.
+
+        Args:
+            x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment.
+            mask: (batch, n_q_heads, target_len, source_len)
+            cache: Optional tuple (past_k, past_v).
+                shape: (batch, n_kv_heads, S_past_padded, head_dim)
+            offset: source_len_padded (scalar, used for RoPE calculation).
+
+        Returns:
+            output_h: (batch, target_len, hidden_dim) - Output hidden states.
+            new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment.
+            new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment.
+        """
+        B, L, D = x.shape
+        if self.q_lora_rank is None:
+            q = self.q_proj(x)
+        else:
+            q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))
+
+        q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
+        q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
+        compressed_kv = self.kv_a_proj_with_mqa(x)
+        compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
+        k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
+        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
+        kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+
+        k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
+
+        q_pe = self.rope(q_pe, offset=offset)
+        k_pe = self.rope(k_pe, offset=offset)
+        k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
+        queries = mx.concatenate([q_nope, q_pe], axis=-1)
+
+        if cache is not None:
+            past_k, past_v = cache
+            if past_k is not None and past_v is not None:
+                if past_k.shape[2] != offset:
+                    raise ValueError(
+                        f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} "
+                        f"to match RoPE offset {offset} (S_past_padded)."
+                    )
+                final_keys_for_attn = mx.concatenate(
+                    [past_k, mx.concatenate([k_nope, k_pe], axis=-1)], axis=2
+                )
+                final_values_for_attn = mx.concatenate([past_v, values], axis=2)
+            else:
+                raise ValueError("cache was provided but one of k/v was None.")
+        else:
+            final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1)
+            final_values_for_attn = values
+
+        if mask is not None:
+            mask = mx.array(mask, dtype=queries.dtype)  # Ensure mask is the same dtype as queries
+        output = scaled_dot_product_attention(
+            queries,
+            final_keys_for_attn,
+            final_values_for_attn,
+            scale=self.scale,
+            mask=mask,
+            cache=None,
+        )
+
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        # print(f"Output values shape: {values.shape}")
+        # print(f"Output k_nope shape: {(mx.concatenate([k_nope, k_pe], axis=-1)).shape}")
+        return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values)
+
+
+class ParallaxDeepSeekV2Block(MLXDeepseekV2Block):
+    """A custom transformer block for Parallax, extending the DeepseekV2 decoder layer.
+    This version handles the KV cache explicitly and returns new K and V states.
+    """
+
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__(args, layer_idx=layer_idx)
+        self.self_attn = ParallaxDeepSeekV2Attention(args)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        offset: int = 0,
+        lengths: Optional[mx.array] = None,
+    ):
+        r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out, (k_cache, v_cache)
+
+    @classmethod
+    def get_architecture(cls):
+        """Get the architecture name for the block."""
+        return "DeepseekV2ForCausalLM"
+
+
+EntryClass = ParallaxDeepSeekV2Block
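Because MLA splits each key into a no-position ("nope") part and a RoPE part, the keys and values returned above for external caching no longer share a head dimension. A minimal sketch with assumed, hypothetical widths (not taken from any real checkpoint) of why the cache needs separate key and value head dims:

# Hypothetical MLA head widths, for illustration only.
qk_nope_head_dim = 128  # assumed
qk_rope_head_dim = 64   # assumed
v_head_dim = 128        # assumed

head_dim_k = qk_nope_head_dim + qk_rope_head_dim  # per-head width of the cached keys (192)
head_dim_v = v_head_dim                           # per-head width of the cached values (128)
assert head_dim_k != head_dim_v                   # hence head_dim_k / head_dim_v in KVCache below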
+""" + +from typing import Optional, Tuple + +import mlx.core as mx +from mlx_lm.models.base import scaled_dot_product_attention +from mlx_lm.models.deepseek_v3 import DeepseekV3Attention as MLXDeepseekV3Attention +from mlx_lm.models.deepseek_v3 import DeepseekV3DecoderLayer as MLXDeepseekV3Block +from mlx_lm.models.deepseek_v3 import ModelArgs + + +class ParallaxKimiK2Attention(MLXDeepseekV3Attention): + """A custom attention module for Parallax, extending the DeepseekV3 Attention class. + + We apply explicit KV cache handling and passing in `offset` directly from Request. + This version returns the new K and V states for external caching. + """ + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + offset: int = 0, + lengths: Optional[mx.array] = None, + ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + """ + Attention forward pass with explicit KV cache handling. + + Args: + x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment. + mask: (batch, n_q_heads, target_len, source_len) + cache: Optional tuple (past_k, past_v). + shape: (batch, n_kv_heads, S_past_padded, head_dim) + offset: source_len_padded (scalar, used for RoPE calculation). + + Returns: + output_h: (batch, target_len, hidden_dim) - Output hidden states. + new_k: (batch, n_kv_heads, target_len, head_dim) - New keys for this segment. + new_v: (batch, n_kv_heads, target_len, head_dim) - New values for this segment. + """ + B, L, D = x.shape + if self.q_lora_rank is None: + q = self.q_proj(x) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x))) + + q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3) + q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1) + compressed_kv = self.kv_a_proj_with_mqa(x) + compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1) + k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3) + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3) + + k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1) + + q_pe = self.rope(q_pe, offset=offset) + k_pe = self.rope(k_pe, offset=offset) + k_pe = mx.repeat(k_pe, self.num_heads, axis=1) + queries = mx.concatenate([q_nope, q_pe], axis=-1) + + if cache is not None: + past_k, past_v = cache + if past_k is not None and past_v is not None: + if past_k.shape[2] != offset: + raise ValueError( + f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} " + f"to match RoPE offset {offset} (S_past_padded)." + ) + final_keys_for_attn = mx.concatenate( + [past_k, mx.concatenate([k_nope, k_pe], axis=-1)], axis=2 + ) + final_values_for_attn = mx.concatenate([past_v, values], axis=2) + else: + raise ValueError("cache was provided but one of k/v was None.") + else: + final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1) + final_values_for_attn = values + + if mask is not None: + mask = mx.array(mask, dtype=queries.dtype) + output = scaled_dot_product_attention( + queries, + final_keys_for_attn, + final_values_for_attn, + scale=self.scale, + mask=mask, + cache=None, + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values) + + +class ParallaxKimiK2Block(MLXDeepseekV3Block): + """A custom transformer block for Parallax, extending the Qwen3 Block class. 
+    This version handles the KV cache explicitly and returns new K and V states.
+    """
+
+    def __init__(self, args: ModelArgs, layer_idx: int):
+        super().__init__(args, layer_idx=layer_idx)
+        self.self_attn = ParallaxKimiK2Attention(args)
+
+    def __call__(
+        self,
+        x: mx.array,
+        mask: Optional[mx.array] = None,
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+        offset: int = 0,
+        lengths: Optional[mx.array] = None,
+    ):
+        r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
+        h = x + r
+        r = self.mlp(self.post_attention_layernorm(h))
+        out = h + r
+        return out, (k_cache, v_cache)
+
+    @classmethod
+    def get_architecture(cls):
+        """Get the architecture name for the block."""
+        return "DeepseekV3ForCausalLM"
+
+
+EntryClass = ParallaxKimiK2Block
diff --git a/src/parallax/server/executor.py b/src/parallax/server/executor.py
index e6638d0d..0c72e3dd 100644
--- a/src/parallax/server/executor.py
+++ b/src/parallax/server/executor.py
@@ -160,6 +160,8 @@ def __init__(
         self.head_dim = self.config.get("head_dim") or self.config.get(
             "hidden_size"
         ) // self.config.get("num_attention_heads")
+        self.qk_nope_head_dim = self.config.get("qk_nope_head_dim", None)
+        self.qk_rope_head_dim = self.config.get("qk_rope_head_dim", None)
         self.enable_prefix_cache = enable_prefix_cache
         self.linear_key_head_dim = self.config.get("linear_key_head_dim", None)
         self.linear_value_head_dim = self.config.get("linear_value_head_dim", None)
@@ -209,6 +211,8 @@ def __init__(
             linear_v_dim=self.linear_value_head_dim,
             linear_num_k_heads=self.linear_num_key_heads,
             linear_num_v_heads=self.linear_num_value_heads,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            qk_rope_head_dim=self.qk_rope_head_dim,
             max_num_tokens=max_tokens_in_kv_pool,
         )
         mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])
diff --git a/src/parallax/server/http_server.py b/src/parallax/server/http_server.py
index 39363730..bfcf15c5 100644
--- a/src/parallax/server/http_server.py
+++ b/src/parallax/server/http_server.py
@@ -29,12 +29,12 @@
 import zmq
 import zmq.asyncio
 from fastapi.responses import ORJSONResponse, StreamingResponse
-from mlx_lm.tokenizer_utils import StreamingDetokenizer, load_tokenizer
+from mlx_lm.tokenizer_utils import StreamingDetokenizer
 from mlx_lm.utils import get_model_path, load_config
 from pydantic import BaseModel
 from starlette.datastructures import State
 
-from parallax.utils.tokenizer_utils import load_detokenizer
+from parallax.utils.tokenizer_utils import load_detokenizer, load_tokenizer
 from parallax.utils.utils import get_zmq_socket
 from parallax_utils.logging_config import get_logger
 
diff --git a/src/parallax/server/kv_cache.py b/src/parallax/server/kv_cache.py
index da19c337..8a2e5ffa 100644
--- a/src/parallax/server/kv_cache.py
+++ b/src/parallax/server/kv_cache.py
@@ -37,7 +37,8 @@ class KVCache:
     def __init__(
         self,
         num_kv_heads: int,
-        head_dim: int,
+        head_dim_k: int,
+        head_dim_v: int,
         num_layers: int,
         dtype: mx.Dtype,
         block_size: int = 64,
@@ -47,6 +48,8 @@ def __init__(
         linear_v_dim: Optional[int] = None,
         linear_num_k_heads: Optional[int] = None,
         linear_num_v_heads: Optional[int] = None,
+        qk_nope_head_dim: Optional[int] = None,
+        qk_rope_head_dim: Optional[int] = None,
         num_initial_tokens: int = 0,
     ):
         """
@@ -59,7 +62,6 @@
             num_initial_tokens: The number of tokens to initialize the cache with.
""" self.num_kv_heads = num_kv_heads - self.head_dim = head_dim self.dtype = dtype self.block_size = block_size self.conv_dim = conv_dim @@ -68,11 +70,18 @@ def __init__( self.linear_v_dim = linear_v_dim self.linear_num_k_heads = linear_num_k_heads self.linear_num_v_heads = linear_num_v_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.head_dim_v = head_dim_v + self.head_dim_k = head_dim_k num_initial_tokens = self.round_up_to_step(num_initial_tokens) # (num_layers, num_kv_heads, seq_len, head_dim) - self.keys = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, head_dim), dtype) - self.values = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, head_dim), dtype) + + self.keys = mx.zeros((num_layers, num_kv_heads, num_initial_tokens, self.head_dim_k), dtype) + self.values = mx.zeros( + (num_layers, num_kv_heads, num_initial_tokens, self.head_dim_v), dtype + ) self.state0 = ( mx.zeros((num_layers, conv_kernel_size - 1, conv_dim), dtype) if conv_dim else None ) @@ -115,8 +124,8 @@ def update( Updates the cache with new key-value pairs. Args: - keys: New keys to add, shape (num_layers, num_kv_heads, target_len, head_dim) - values: New values to add, shape (num_layers, num_kv_heads, target_len, head_dim) + keys: New keys to add, shape (num_layers, num_kv_heads, target_len, head_dim_k) + values: New values to add, shape (num_layers, num_kv_heads, target_len, head_dim_v) """ if state0 is not None and self.state0 is not None: self.state0 = state0 @@ -128,10 +137,11 @@ def update( prev_tokens = self.num_tokens # Grow the cache based on the block_size size if self.needs_grow(seq_len): - num_layers, num_kv_heads, _, head_dim = keys.shape + num_layers, num_kv_heads, _, head_dim_k = keys.shape + _, _, _, head_dim_v = values.shape n_steps = (self.block_size + seq_len - 1) // self.block_size - k_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim) - v_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim) + k_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim_k) + v_shape = (num_layers, num_kv_heads, n_steps * self.block_size, head_dim_v) new_k = mx.zeros(k_shape, keys.dtype) new_v = mx.zeros(v_shape, values.dtype) @@ -167,6 +177,8 @@ def __init__( linear_v_dim: Optional[int] = None, linear_num_k_heads: Optional[int] = None, linear_num_v_heads: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, ): """ Args: @@ -179,7 +191,6 @@ def __init__( cache_memory_fraction: The fraction of memory to use for the cache. 
""" self.num_kv_heads = num_kv_heads - self.head_dim = head_dim self.num_layers = num_layers self.dtype = dtype self.block_size = block_size @@ -189,6 +200,13 @@ def __init__( self.linear_v_dim = linear_v_dim self.linear_num_k_heads = linear_num_k_heads self.linear_num_v_heads = linear_num_v_heads + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + if qk_nope_head_dim and qk_rope_head_dim: + self.head_dim_k = qk_nope_head_dim + qk_rope_head_dim + else: + self.head_dim_k = head_dim + self.head_dim_v = head_dim self.request_caches: Dict[str, KVCache] = {} self.tokens_in_cache = 0 @@ -198,7 +216,8 @@ def __init__( kv_cache_memory_fraction=cache_memory_fraction, num_shard_layers=num_layers, num_key_value_heads=num_kv_heads, - head_dim=head_dim, + head_dim_k=self.head_dim_k, + head_dim_v=self.head_dim_v, elem_bytes=dtype.size, ) if max_num_tokens is not None: @@ -264,7 +283,8 @@ def add_request(self, request: Request, num_tokens: int = 128) -> bool: self.request_caches[request.request_id] = KVCache( num_kv_heads=self.num_kv_heads, - head_dim=self.head_dim, + head_dim_k=self.head_dim_k, + head_dim_v=self.head_dim_v, num_layers=self.num_layers, dtype=self.dtype, block_size=self.block_size, @@ -311,16 +331,18 @@ def update_requests( Returns: True if requests are updated. """ - batch_size, num_layers, n_kv_heads, _, head_dim = keys.shape + batch_size, num_layers, n_kv_heads, _, head_dim_k = keys.shape + _, _, _, _, head_dim_v = values.shape # Validate - assert keys.shape == values.shape, "key and value must have the same shape" + # assert keys.shape == values.shape, "key and value must have the same shape" assert num_layers == self.num_layers, "key and value must have the same number of layers" assert batch_size == len(requests), "key and value must have the same batch size" assert len(lengths) == batch_size, "lengths must have the same batch size as requests" assert ( n_kv_heads == self.num_kv_heads ), "key and value must have the same number of key-value heads" - assert head_dim == self.head_dim, "key and value must have the same head dimension" + assert head_dim_k == self.head_dim_k, "key and value must have the same head dimension" + assert head_dim_v == self.head_dim_v, "key and value must have the same head dimension" # TODO: Use vmap for better performance for request, key, value, length, state0, state1 in zip( requests, keys, values, lengths, states0, states1 diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 7b4f12b6..f15250d6 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -10,10 +10,10 @@ import mlx.core as mx import safetensors from mlx import nn -from mlx_lm.tokenizer_utils import load_tokenizer from mlx_lm.utils import get_model_path, load_config from parallax.server.model import ShardedModel +from parallax.utils.tokenizer_utils import load_tokenizer from parallax_utils.logging_config import get_logger logger = get_logger(__name__) @@ -115,6 +115,8 @@ def load( # We need the model object to know its structure and which layers it owns. # This part mirrors the logic from the provided utils.py to get model_args. 
         model_type = config.get("model_type")
+        if model_type == "kimi_k2":
+            model_type = "deepseek_v3"
         if not model_type:
             raise ValueError("model_type not found in config.json")
         try:
diff --git a/src/parallax/sglang/model_runner.py b/src/parallax/sglang/model_runner.py
index b043bdef..63df8ab0 100755
--- a/src/parallax/sglang/model_runner.py
+++ b/src/parallax/sglang/model_runner.py
@@ -13,7 +13,6 @@
 import sglang
 import sglang.srt.distributed.parallel_state
 import torch
-from mlx_lm.tokenizer_utils import load_tokenizer
 from mlx_lm.utils import get_model_path, load_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import (
@@ -44,6 +43,8 @@
 )
 from torch.distributed import Backend
 
+from parallax.utils.tokenizer_utils import load_tokenizer
+
 # from parallax.sglang.monkey_patch.model_runner import ModelRunner as SGLModelRunner
 
 logger = logging.getLogger(__name__)
diff --git a/src/parallax/utils/tokenizer_utils.py b/src/parallax/utils/tokenizer_utils.py
index 463ef59d..faefef22 100755
--- a/src/parallax/utils/tokenizer_utils.py
+++ b/src/parallax/utils/tokenizer_utils.py
@@ -14,6 +14,7 @@
     _is_spm_decoder,
     _is_spm_decoder_no_space,
 )
+from mlx_lm.tokenizer_utils import load_tokenizer as _mlx_load_tokenizer
 
 
 class ParallaxNaiveStreamingDetokenizer(NaiveStreamingDetokenizer):
@@ -97,3 +98,28 @@ def load_detokenizer(model_path, tokenizer):
         tokenmap = _get_bpe_tokenmap(tokenizer)
 
     return detokenizer_class, tokenmap
+
+
+def load_tokenizer(model_path, trust_remote_code=True, tokenizer_config_extra=None, **kwargs):
+    """
+    Wrapper function for MLX load_tokenizer that defaults trust_remote_code to True.
+    This is needed for models like Kimi-K2 that contain custom code.
+
+    Args:
+        model_path: Path to the model
+        trust_remote_code: Whether to trust remote code (defaults to True)
+        tokenizer_config_extra: Extra config to pass to AutoTokenizer.from_pretrained
+        **kwargs: Additional arguments to pass to the original load_tokenizer
+
+    Returns:
+        The loaded tokenizer
+    """
+    if tokenizer_config_extra is None:
+        tokenizer_config_extra = {}
+
+    # Add trust_remote_code to the tokenizer config
+    if trust_remote_code:
+        tokenizer_config_extra = tokenizer_config_extra.copy()
+        tokenizer_config_extra["trust_remote_code"] = True
+
+    return _mlx_load_tokenizer(model_path, tokenizer_config_extra=tokenizer_config_extra, **kwargs)
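A minimal usage sketch of the load_tokenizer wrapper above, assuming the model files are already available locally (the path is illustrative, not a real location):

from pathlib import Path

from parallax.utils.tokenizer_utils import load_tokenizer

# trust_remote_code defaults to True here, which Kimi-K2's custom tokenizer code requires.
tokenizer = load_tokenizer(Path("/path/to/Kimi-K2-Instruct"))
print(tokenizer.decode(tokenizer.encode("hello")))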
diff --git a/src/parallax_utils/utils.py b/src/parallax_utils/utils.py
index 53049c2f..ea712654 100644
--- a/src/parallax_utils/utils.py
+++ b/src/parallax_utils/utils.py
@@ -47,7 +47,8 @@ def compute_max_tokens_in_cache(
     kv_cache_memory_fraction: float,
     num_shard_layers: int,
     num_key_value_heads: int,
-    head_dim: int,
+    head_dim_k: int,
+    head_dim_v: int,
     elem_bytes: int,
     available_cache_bytes: Optional[int] = None,
 ) -> int:
@@ -65,7 +66,9 @@
     hw = HardwareInfo.detect()
     used = mx.get_active_memory() if mx is not None else 0
     available_cache_size = int((hw.total_ram_gb * 1024**3 - used) * kv_cache_memory_fraction)
-    per_token_cache_size = num_shard_layers * num_key_value_heads * head_dim * 2 * elem_bytes
+    per_token_cache_size = (
+        num_shard_layers * num_key_value_heads * (head_dim_k + head_dim_v) * elem_bytes
+    )
     return max(0, available_cache_size // per_token_cache_size)
 
 
@@ -101,6 +104,8 @@
     dtype=None,
     elem_bytes: Optional[int] = None,
     memory_gb: Optional[float] = None,
+    head_dim_k: Optional[int] = None,
+    head_dim_v: Optional[int] = None,
 ) -> int:
     """Compute final max_batch_size by chaining
     dtype->elem_bytes, KV capacity, and clamping.
@@ -110,12 +115,14 @@
     available_cache_bytes = None
     if memory_gb is not None:
         available_cache_bytes = int(memory_gb * 1024**3 * kv_cache_memory_fraction)
+    # K and V may use different head dims (e.g. MLA models), so pass both when available.
     max_tokens = compute_max_tokens_in_cache(
         device=device or "",  # empty means non-cuda path
        kv_cache_memory_fraction=kv_cache_memory_fraction,
         num_shard_layers=num_shard_layers,
         num_key_value_heads=num_key_value_heads,
-        head_dim=head_dim,
+        head_dim_k=head_dim_k if head_dim_k is not None else head_dim,
+        head_dim_v=head_dim_v if head_dim_v is not None else head_dim,
         elem_bytes=eb,
         available_cache_bytes=available_cache_bytes,
     )
diff --git a/src/scheduling/model_info.py b/src/scheduling/model_info.py
index a6c14680..bd49e060 100644
--- a/src/scheduling/model_info.py
+++ b/src/scheduling/model_info.py
@@ -35,10 +35,30 @@ class ModelInfo:
     cache_bytes_per_element: int = 1
     embedding_bytes_per_element: int = 1
 
+    qk_nope_head_dim: Optional[int] = None
+    qk_rope_head_dim: Optional[int] = None
+    head_size_k: Optional[int] = None
+    head_size_v: Optional[int] = None
+
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+
+        if self.qk_nope_head_dim is not None and self.qk_rope_head_dim is not None:
+            self.head_size_k = self.qk_nope_head_dim + self.qk_rope_head_dim
+        else:
+            self.head_size_k = self.head_size
+        self.head_size_v = self.head_size
+
     @property
-    def kv_dim(self) -> int:
+    def v_dim(self) -> int:
         """Return key and value head dim."""
-        return self.num_kv_heads * self.head_size
+        return self.num_kv_heads * self.head_size_v
+
+    @property
+    def k_dim(self) -> int:
+        """Return total key dim across KV heads."""
+        return self.num_kv_heads * self.head_size_k
 
     @property
     def embedding_io_bytes(self) -> int:
@@ -48,7 +68,7 @@ def embedding_io_bytes(self) -> int:
     @property
     def per_token_per_layer_kv_size(self) -> int:
         """Return bytes per token for KV cache."""
-        return 2 * self.cache_bytes_per_element * self.kv_dim
+        return self.cache_bytes_per_element * (self.k_dim + self.v_dim)
 
     def per_layer_kv_cache_size(self, *, batch_size: int = 1, source_seq_len: int = 256) -> int:
         """Return size of KV cache in bytes for given request dimensions."""
@@ -81,7 +101,7 @@ def decoder_layer_flops(
         # Q/O projections: (T, hidden_dim) @ (hidden_dim, hidden_dim)
         qo_flops = 2 * 2 * target_seq_len * self.hidden_dim * self.hidden_dim
         # K/V projections: (T, hidden_dim) @ (hidden_dim, kv_dim)
-        kv_flops = 2 * 2 * target_seq_len * self.hidden_dim * self.kv_dim
+        kv_flops = 2 * target_seq_len * self.hidden_dim * (self.k_dim + self.v_dim)
         projection_flops = qo_flops + kv_flops
 
         # 'roof' estimation for GQA
@@ -124,7 +144,7 @@ def decoder_layer_io_bytes(
         """
         # Attention params
         qo_params = self.param_bytes_per_element * self.hidden_dim * self.hidden_dim
-        kv_params = self.param_bytes_per_element * self.hidden_dim * self.kv_dim
+        kv_params = self.param_bytes_per_element * self.hidden_dim * (self.k_dim + self.v_dim) // 2
         attention_params = qo_params + kv_params
 
         # FFN params
diff --git a/src/scheduling/node.py b/src/scheduling/node.py
index 13a7a9f3..6e40b7eb 100644
--- a/src/scheduling/node.py
+++ b/src/scheduling/node.py
@@ -222,6 +222,8 @@ def max_requests(self) -> int:
             head_dim=self.model_info.head_size,
             elem_bytes=elem_bytes,
             memory_gb=self.hardware.memory_gb,
+            head_dim_k=self.model_info.head_size_k,
+            head_dim_v=self.model_info.head_size_v,
         )
         if derived_max <= 0:
             raise ValueError(
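A rough worked example of the sizing that ModelInfo and compute_max_tokens_in_cache now perform, using hypothetical MLA-style numbers (not taken from any real config):

# Assumed values, for illustration only.
num_kv_heads = 16
head_size = 128              # value head size; head_size_v falls back to this
qk_nope_head_dim = 128
qk_rope_head_dim = 64
cache_bytes_per_element = 2  # e.g. a bf16 KV cache

head_size_k = qk_nope_head_dim + qk_rope_head_dim                 # 192, as in ModelInfo.__init__
head_size_v = head_size                                           # 128
k_dim = num_kv_heads * head_size_k                                # 3072
v_dim = num_kv_heads * head_size_v                                # 2048
per_token_per_layer = cache_bytes_per_element * (k_dim + v_dim)   # 10240 bytes per token per layer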
diff --git a/tests/test_executor.py b/tests/test_executor.py
index fe66c18d..1dd16093 100644
--- a/tests/test_executor.py
+++ b/tests/test_executor.py
@@ -4,11 +4,11 @@
 
 import pytest
 from mlx_lm.generate import generate
-from mlx_lm.tokenizer_utils import load_tokenizer
 from mlx_lm.utils import get_model_path, load_model
 
 from parallax.server.executor import Executor
 from parallax.server.request import InitialRequest
+from parallax.utils.tokenizer_utils import load_tokenizer
 
 MODEL_REPO = "mlx-community/Qwen3-0.6B-bf16"
 
diff --git a/tests/test_model.py b/tests/test_model.py
index c839e91d..a62dd2db 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -7,11 +7,11 @@
 import mlx.core as mx
 import pytest
 from mlx_lm.models.base import create_attention_mask
-from mlx_lm.tokenizer_utils import load_tokenizer
 from mlx_lm.utils import get_model_path, load_model
 
 from parallax.server.server_info import ShardedModelInfo
 from parallax.server.shard_loader import MLXModelLoader
+from parallax.utils.tokenizer_utils import load_tokenizer
 from parallax.utils.utils import pad_inputs
 
 REPO_ID = "mlx-community/Qwen3-0.6B-bf16"
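Finally, a small sketch of constructing the updated KVCache directly with distinct key and value head dims. The sizes are assumed, and only the constructor signature shown in the diff above is relied on:

import mlx.core as mx

from parallax.server.kv_cache import KVCache

# Assumed shard and head sizes, for illustration only.
cache = KVCache(
    num_kv_heads=16,
    head_dim_k=192,   # e.g. qk_nope_head_dim + qk_rope_head_dim
    head_dim_v=128,   # e.g. v_head_dim
    num_layers=4,
    dtype=mx.bfloat16,
)
# cache.keys has shape (num_layers, num_kv_heads, S, head_dim_k)
# cache.values has shape (num_layers, num_kv_heads, S, head_dim_v)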