Merged
12 changes: 8 additions & 4 deletions src/backend/server/static_config.py
@@ -7,6 +7,12 @@
# Supported model list
MODEL_LIST = [
"Qwen/Qwen3-0.6B",
"openai/gpt-oss-20b",
"openai/gpt-oss-120b",
"moonshotai/Kimi-K2-Instruct",
"moonshotai/Kimi-K2-Instruct-0905",
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"Qwen/Qwen3-Next-80B-A3B-Thinking",
# "Qwen/Qwen3-8B",
# "Qwen/Qwen3-8B-FP8",
"Qwen/Qwen3-32B",
@@ -16,14 +22,10 @@
# "Qwen/Qwen3-30B-A3B-Thinking-2507-FP8",
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8",
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8",
"Qwen/Qwen3-Next-80B-A3B-Instruct",
"Qwen/Qwen3-Next-80B-A3B-Thinking",
# "Qwen/Qwen2.5-3B-Instruct",
# "Qwen/Qwen2.5-7B-Instruct",
# "Qwen/Qwen2.5-14B-Instruct",
"Qwen/Qwen2.5-72B-Instruct",
"openai/gpt-oss-20b",
"openai/gpt-oss-120b",
"nvidia/Llama-3.3-70B-Instruct-FP8",
"nvidia/Llama-3.1-70B-Instruct-FP8",
"nvidia/Llama-3.1-8B-Instruct-FP8",
@@ -56,6 +58,8 @@ def get_model_info(model_name):
model_info = ModelInfo(
model_name=model_name,
head_size=config.get("head_dim", 128),
qk_nope_head_dim=config.get("qk_nope_head_dim", None),
qk_rope_head_dim=config.get("qk_rope_head_dim", None),
hidden_dim=config.get("hidden_size", 0),
intermediate_dim=config.get("intermediate_size", 0),
num_attention_heads=config.get("num_attention_heads", 0),
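For context: the two fields added to ModelInfo above carry the MLA (multi-head latent attention) head split used by the DeepSeek-style checkpoints newly added to MODEL_LIST. A minimal consumption sketch, assuming get_model_info accepts the repo ids listed above; the fallback logic is an illustrative assumption, not code from this PR:

# Sketch only: field names follow the diff above.
info = get_model_info("moonshotai/Kimi-K2-Instruct")
if info.qk_nope_head_dim is not None and info.qk_rope_head_dim is not None:
    # MLA checkpoints split each query/key head into a no-RoPE part and a RoPE part.
    k_head_dim = info.qk_nope_head_dim + info.qk_rope_head_dim
else:
    # Non-MLA checkpoints use a single per-head width.
    k_head_dim = info.head_size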
130 changes: 130 additions & 0 deletions src/parallax/models/deepseek_v2.py
@@ -0,0 +1,130 @@
"""
Defines the DeepSeek V2 model for Parallax.
"""

from typing import Optional, Tuple

import mlx.core as mx
from mlx_lm.models.base import scaled_dot_product_attention
from mlx_lm.models.deepseek_v2 import DeepseekV2Attention as MLXDeepseekV2Attention
from mlx_lm.models.deepseek_v2 import DeepseekV2DecoderLayer as MLXDeepseekV2Block
from mlx_lm.models.deepseek_v2 import ModelArgs


class ParallaxDeepSeekV2Attention(MLXDeepseekV2Attention):
"""A custom attention module for Parallax, extending the DeepseekV2 Attention class.

We apply explicit KV cache handling and passing in `offset` directly from Request.
This version returns the new K and V states for external caching.
"""

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
offset: int = 0,
lengths: Optional[mx.array] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
"""
Attention forward pass with explicit KV cache handling.

Args:
x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment.
mask: (batch, n_q_heads, target_len, source_len)
cache: Optional tuple (past_k, past_v).
past_k: (batch, num_heads, S_past_padded, qk_nope_head_dim + qk_rope_head_dim); past_v: (batch, num_heads, S_past_padded, v_head_dim)
offset: source_len_padded (scalar, used for RoPE calculation).

Returns:
output_h: (batch, target_len, hidden_dim) - Output hidden states.
new_k: (batch, num_heads, target_len, qk_nope_head_dim + qk_rope_head_dim) - New keys for this segment.
new_v: (batch, num_heads, target_len, v_head_dim) - New values for this segment.
"""
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

q_pe = self.rope(q_pe, offset=offset)
k_pe = self.rope(k_pe, offset=offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)

if cache is not None:
past_k, past_v = cache
if past_k is not None and past_v is not None:
if past_k.shape[2] != offset:
raise ValueError(
f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} "
f"to match RoPE offset {offset} (S_past_padded)."
)
final_keys_for_attn = mx.concatenate(
[past_k, mx.concatenate([k_nope, k_pe], axis=-1)], axis=2
)
final_values_for_attn = mx.concatenate([past_v, values], axis=2)
else:
raise ValueError("cache was provided but one of k/v was None.")
else:
final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1)
final_values_for_attn = values

if mask is not None:
mask = mx.array(mask, dtype=queries.dtype) # Ensure mask is the same dtype as queries
output = scaled_dot_product_attention(
queries,
final_keys_for_attn,
final_values_for_attn,
scale=self.scale,
mask=mask,
cache=None,
)

output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)

# print(f"Output values shape: {values.shape}")
# print(f"Output k_nope shape: {(mx.concatenate([k_nope, k_pe], axis=-1)).shape}")
return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values)


class ParallaxDeepSeekV2Block(MLXDeepseekV2Block):
"""A custom transformer block for Parallax, extending the Qwen3 Block class.
This version handles the KV cache explicitly and returns new K and V states.
"""

def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__(args, layer_idx=layer_idx)
self.self_attn = ParallaxDeepSeekV2Attention(args)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
offset: int = 0,
lengths: Optional[mx.array] = None,
):
r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, (k_cache, v_cache)

@classmethod
def get_architecture(cls):
"""Get the architecture name for the block."""
return "DeepseekV2ForCausalLM"


EntryClass = ParallaxDeepSeekV2Block
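For readers new to the explicit-cache contract these blocks expose (new K/V handed back to the caller, RoPE offset passed in), here is a hedged usage sketch. The way `args` is built, the mask construction, and the zero inputs are illustrative assumptions, not code from the Parallax executor:

# Illustrative only: assumes `args` is a valid mlx_lm DeepseekV2 ModelArgs.
import mlx.core as mx

block = ParallaxDeepSeekV2Block(args, layer_idx=0)

# Prefill: no past cache, RoPE offset 0, additive causal mask over the prompt.
prompt_len = 8
h = mx.zeros((1, prompt_len, args.hidden_size))
causal = mx.triu(mx.full((prompt_len, prompt_len), float("-inf")), k=1)
out, (past_k, past_v) = block(h, mask=causal, cache=None, offset=0)

# Decode one token: offset must equal past_k.shape[2], or the attention raises ValueError.
h_step = mx.zeros((1, 1, args.hidden_size))
out, (k_new, v_new) = block(h_step, mask=None, cache=(past_k, past_v), offset=past_k.shape[2])
past_k = mx.concatenate([past_k, k_new], axis=2)
past_v = mx.concatenate([past_v, v_new], axis=2)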
127 changes: 127 additions & 0 deletions src/parallax/models/kimi_k2.py
@@ -0,0 +1,127 @@
"""
Defines the Kimi K2 model (DeepSeek V3 architecture) for Parallax.
"""

from typing import Optional, Tuple

import mlx.core as mx
from mlx_lm.models.base import scaled_dot_product_attention
from mlx_lm.models.deepseek_v3 import DeepseekV3Attention as MLXDeepseekV3Attention
from mlx_lm.models.deepseek_v3 import DeepseekV3DecoderLayer as MLXDeepseekV3Block
from mlx_lm.models.deepseek_v3 import ModelArgs


class ParallaxKimiK2Attention(MLXDeepseekV3Attention):
"""A custom attention module for Parallax, extending the DeepseekV3 Attention class.

We apply explicit KV cache handling and passing in `offset` directly from Request.
This version returns the new K and V states for external caching.
"""

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
offset: int = 0,
lengths: Optional[mx.array] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
"""
Attention forward pass with explicit KV cache handling.

Args:
x: (batch, target_len, hidden_dim) - Input hidden states for the current query segment.
mask: (batch, n_q_heads, target_len, source_len)
cache: Optional tuple (past_k, past_v).
past_k: (batch, num_heads, S_past_padded, qk_nope_head_dim + qk_rope_head_dim); past_v: (batch, num_heads, S_past_padded, v_head_dim)
offset: source_len_padded (scalar, used for RoPE calculation).

Returns:
output_h: (batch, target_len, hidden_dim) - Output hidden states.
new_k: (batch, num_heads, target_len, qk_nope_head_dim + qk_rope_head_dim) - New keys for this segment.
new_v: (batch, num_heads, target_len, v_head_dim) - New values for this segment.
"""
B, L, D = x.shape
if self.q_lora_rank is None:
q = self.q_proj(x)
else:
q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(x)))

q = q.reshape(B, L, self.num_heads, self.q_head_dim).transpose(0, 2, 1, 3)
q_nope, q_pe = mx.split(q, [self.qk_nope_head_dim], axis=-1)
compressed_kv = self.kv_a_proj_with_mqa(x)
compressed_kv, k_pe = mx.split(compressed_kv, [self.kv_lora_rank], axis=-1)
k_pe = k_pe.reshape(B, L, 1, self.qk_rope_head_dim).transpose(0, 2, 1, 3)
kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv))
kv = kv.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)

k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)

q_pe = self.rope(q_pe, offset=offset)
k_pe = self.rope(k_pe, offset=offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)

if cache is not None:
past_k, past_v = cache
if past_k is not None and past_v is not None:
if past_k.shape[2] != offset:
raise ValueError(
f"ParallaxAttention: Expected past_k sequence length {past_k.shape[2]} "
f"to match RoPE offset {offset} (S_past_padded)."
)
final_keys_for_attn = mx.concatenate(
[past_k, mx.concatenate([k_nope, k_pe], axis=-1)], axis=2
)
final_values_for_attn = mx.concatenate([past_v, values], axis=2)
else:
raise ValueError("cache was provided but one of k/v was None.")
else:
final_keys_for_attn = mx.concatenate([k_nope, k_pe], axis=-1)
final_values_for_attn = values

if mask is not None:
mask = mx.array(mask, dtype=queries.dtype)
output = scaled_dot_product_attention(
queries,
final_keys_for_attn,
final_values_for_attn,
scale=self.scale,
mask=mask,
cache=None,
)

output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.o_proj(output), (mx.concatenate([k_nope, k_pe], axis=-1), values)


class ParallaxKimiK2Block(MLXDeepseekV3Block):
"""A custom transformer block for Parallax, extending the Qwen3 Block class.
This version handles the KV cache explicitly and returns new K and V states.
"""

def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__(args, layer_idx=layer_idx)
self.self_attn = ParallaxKimiK2Attention(args)

def __call__(
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
offset: int = 0,
lengths: Optional[mx.array] = None,
):
r, (k_cache, v_cache) = self.self_attn(self.input_layernorm(x), mask, cache, offset=offset)
h = x + r
r = self.mlp(self.post_attention_layernorm(h))
out = h + r
return out, (k_cache, v_cache)

@classmethod
def get_architecture(cls):
"""Get the architecture name for the block."""
return "DeepseekV3ForCausalLM"


EntryClass = ParallaxKimiK2Block
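Kimi-K2 reuses the DeepSeek-V3 MLA layout, so the K and V tensors returned for external caching do not share a single head_dim: cached keys hold the no-RoPE and RoPE parts concatenated, cached values hold v_head_dim. A small shape sketch with illustrative numbers (not Kimi-K2's real config values):

# Illustrative values only; the real ones come from the checkpoint's config.json.
qk_nope_head_dim = 128
qk_rope_head_dim = 64
v_head_dim = 128
num_heads = 16

k_head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192, the last axis of new_k above
# new_k: (batch, num_heads, target_len, k_head_dim)
# new_v: (batch, num_heads, target_len, v_head_dim)

This asymmetry is presumably why the executor change below forwards qk_nope_head_dim and qk_rope_head_dim into the cache configuration instead of relying on head_dim alone.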
4 changes: 4 additions & 0 deletions src/parallax/server/executor.py
@@ -160,6 +160,8 @@ def __init__(
self.head_dim = self.config.get("head_dim") or self.config.get(
"hidden_size"
) // self.config.get("num_attention_heads")
self.qk_nope_head_dim = self.config.get("qk_nope_head_dim", None)
self.qk_rope_head_dim = self.config.get("qk_rope_head_dim", None)
self.enable_prefix_cache = enable_prefix_cache
self.linear_key_head_dim = self.config.get("linear_key_head_dim", None)
self.linear_value_head_dim = self.config.get("linear_value_head_dim", None)
@@ -209,6 +211,8 @@ def __init__(
linear_v_dim=self.linear_value_head_dim,
linear_num_k_heads=self.linear_num_key_heads,
linear_num_v_heads=self.linear_num_value_heads,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
max_num_tokens=max_tokens_in_kv_pool,
)
mx.set_wired_limit(mx.metal.device_info()["max_recommended_working_set_size"])
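As a hedged illustration of why both MLA fields are threaded into the cache configuration here (the helper below is hypothetical and not a function in the Parallax codebase), per-token KV width differs between standard attention and the MLA layout used by the DeepSeek-style models:

# Hypothetical helper, shown for illustration only.
def kv_elems_per_token(num_heads, head_dim, qk_nope_head_dim=None, qk_rope_head_dim=None, v_head_dim=None):
    """Elements cached per token per layer, under the shapes returned by the blocks above."""
    if qk_nope_head_dim is not None and qk_rope_head_dim is not None:
        # MLA: keys and values use different per-head widths.
        k_dim = qk_nope_head_dim + qk_rope_head_dim
        v_dim = v_head_dim if v_head_dim is not None else head_dim
        return num_heads * (k_dim + v_dim)
    # Standard attention: K and V both use head_dim.
    return num_heads * 2 * head_dim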
4 changes: 2 additions & 2 deletions src/parallax/server/http_server.py
@@ -29,12 +29,12 @@
import zmq
import zmq.asyncio
from fastapi.responses import ORJSONResponse, StreamingResponse
from mlx_lm.tokenizer_utils import StreamingDetokenizer, load_tokenizer
from mlx_lm.tokenizer_utils import StreamingDetokenizer
from mlx_lm.utils import get_model_path, load_config
from pydantic import BaseModel
from starlette.datastructures import State

from parallax.utils.tokenizer_utils import load_detokenizer
from parallax.utils.tokenizer_utils import load_detokenizer, load_tokenizer
from parallax.utils.utils import get_zmq_socket
from parallax_utils.logging_config import get_logger
